
Commit e4badc1

limit D == 256 to 8 warps
1 parent: ffe03eb

1 file changed: +47 −33 lines

ggml-cuda.cu

Lines changed: 47 additions & 33 deletions
@@ -7506,8 +7506,8 @@ static __global__ void flash_attn_f32(
     }
 }
 
-template<int D, int ncols> // D head size
-__launch_bounds__(ncols == 8 ? (D + D % 32) : 2*D, 1)
+template<int D, int ncols> // D == head size
+__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -7545,9 +7545,11 @@ static __global__ void flash_attn_ext_f16(
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;
 
-    constexpr int nwarps = D / frag_m;
+    constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
     constexpr int nthreads = nwarps*WARP_SIZE;
     static_assert(nthreads % D == 0, "nthreads not divisible by D.");
+    constexpr int tc_vals_per_iter = nwarps*frag_m;
+    static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
     __builtin_assume(tid < nthreads);
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
@@ -7608,25 +7610,28 @@ static __global__ void flash_attn_ext_f16(
         const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11;
 
         // Calculate tile of KQ:
-        frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-        }
-        if (has_valid_data) {
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+            frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+            if (has_valid_data) {
 #pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                    frag_a_K K_a;
+                    nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                    for (int j = 0; j < ncols/frag_n; ++j) {
+                        nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    }
                 }
             }
-        }
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            }
         }
 
         __syncthreads();
@@ -7687,31 +7692,40 @@ static __global__ void flash_attn_ext_f16(
             }
         }
 
-        frag_c VKQ_c[ncols/frag_n];
+        frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[j], 0.0f);
-        }
-
-#pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 16) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f);
             }
 
-            frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + frag_m*threadIdx.y, stride_KV);
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[j], v_a, KQ_b[k0/16][j], VKQ_c[j]);
+#pragma unroll
+            for (int k0 = 0; k0 < D; k0 += 16) {
+                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]);
+                }
             }
         }
 
         __syncthreads();
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, VKQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y,
+                    VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
         }
 
         __syncthreads();
@@ -11453,7 +11467,7 @@ inline void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = Q->ne[0] / frag_m;
+    const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;
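Note on the launch configuration (a hedged reading of the diff, not text from the commit): the old kernel launched nwarps = D / frag_m warps, so D == 256 with frag_m == 16 meant 16 warps (512 threads). The new nwarps formula halves that for D > 128 unless ncols == 8, and the new tc_vals_per_iter = nwarps*frag_m stride tiles the head dimension, so each outer i_KQ_0 iteration covers only part of D; correspondingly, __launch_bounds__ now requests D threads instead of 2*D when D > 128. The standalone sketch below, assuming frag_m is chosen as in ggml_cuda_flash_attn_ext (32 when cols_per_block == 8, otherwise 16), simply evaluates the new formulas for a few head sizes:

// Hedged sketch (not from the commit): evaluates the new nwarps and
// tc_vals_per_iter formulas for a few head sizes D, assuming frag_m is
// chosen as in ggml_cuda_flash_attn_ext (32 when cols_per_block == 8, else 16).
#include <cstdio>

static void report(int D, int cols_per_block) {
    const int frag_m           = cols_per_block == 8 ? 32 : 16;
    const int nwarps           = (D <= 128 || cols_per_block == 8 ? D : D/2) / frag_m;
    const int tc_vals_per_iter = nwarps*frag_m; // rows of the head dimension covered per loop iteration
    printf("D=%3d cols=%2d -> %2d warps, %d pass(es) over D\n",
           D, cols_per_block, nwarps, D/tc_vals_per_iter);
}

int main() {
    report(128, 16); // 8 warps, 1 pass (unchanged behavior)
    report(256, 16); // 8 warps, 2 passes (was 16 warps before this commit)
    report(256,  8); // 8 warps, 1 pass
    return 0;
}

Under these assumptions every listed configuration lands on 8 warps (256 threads), which matches what the commit title describes for D == 256.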
