@@ -7506,8 +7506,8 @@ static __global__ void flash_attn_f32(
     }
 }
 
-template<int D, int ncols> // D head size
-__launch_bounds__(ncols == 8 ? (D + D % 32) : 2*D, 1)
+template<int D, int ncols> // D == head size
+__launch_bounds__(ncols == 8 || D > 128 ? D : 2*D, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
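
Note on the `__launch_bounds__` change: the old bound always compiled the kernel for up to 2*D threads per block, while the new expression caps the block at D threads whenever ncols == 8 or the head size exceeds 128, matching the reduced warp count introduced further down. A minimal sketch of what the new expression evaluates to for a few instantiations (the helper `launch_bound_threads` is illustrative only, not part of the patch):

```cuda
// Illustrative only: evaluates the new __launch_bounds__ upper bound for a few (D, ncols) pairs.
constexpr int launch_bound_threads(int D, int ncols) {
    return ncols == 8 || D > 128 ? D : 2*D;
}
static_assert(launch_bound_threads( 64, 16) == 128, "D <= 128, ncols > 8: 2*D threads");
static_assert(launch_bound_threads(128, 16) == 256, "D <= 128, ncols > 8: 2*D threads");
static_assert(launch_bound_threads(256, 16) == 256, "D > 128: capped at D threads");
static_assert(launch_bound_threads(128,  8) == 128, "ncols == 8: D threads");
```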
@@ -7545,9 +7545,11 @@ static __global__ void flash_attn_ext_f16(
     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c;
 
-    constexpr int nwarps = D / frag_m;
+    constexpr int nwarps = (D <= 128 || ncols == 8 ? D : D/2) / frag_m;
     constexpr int nthreads = nwarps*WARP_SIZE;
     static_assert(nthreads % D == 0, "nthreads not divisible by D.");
+    constexpr int tc_vals_per_iter = nwarps*frag_m;
+    static_assert(D % tc_vals_per_iter == 0, "D not divisible by tensor core vals per iter.");
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
     __builtin_assume(tid < nthreads);
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.
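
With fewer warps for D > 128, one pass of the tensor cores covers only `tc_vals_per_iter = nwarps*frag_m` values per tile, which is what the new constant and its `static_assert` express. A worked example under the assumptions frag_m == 16 and ncols > 8 (all `_ex` names are illustrative, not part of the patch):

```cuda
// Illustrative arithmetic for D == 256, ncols == 16, frag_m == 16:
constexpr int frag_m_ex           = 16;
constexpr int D_ex                = 256;
constexpr int nwarps_ex           = (D_ex/2) / frag_m_ex;      // D > 128 => half as many warps: 8
constexpr int tc_vals_per_iter_ex = nwarps_ex * frag_m_ex;     // 8*16 = 128 values covered per pass
static_assert(D_ex % tc_vals_per_iter_ex == 0, "tile size must split evenly across passes");
static_assert(D_ex / tc_vals_per_iter_ex == 2, "the KQ/VKQ loops below run two tensor core passes");
```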
@@ -7608,25 +7610,28 @@ static __global__ void flash_attn_ext_f16(
         const bool has_valid_data = 256 % D == 0 || k_VKQ_0 + frag_m*threadIdx.y < ne11;
 
         // Calculate tile of KQ:
-        frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-        }
-        if (has_valid_data) {
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+            frag_c KQ_c[ncols/frag_n];
 #pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+            if (has_valid_data) {
 #pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                    frag_a_K K_a;
+                    nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                    for (int j = 0; j < ncols/frag_n; ++j) {
+                        nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    }
                 }
             }
-        }
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+            }
         }
 
         __syncthreads();
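
The KQ computation is now wrapped in a loop over `i_KQ_0`: each pass fills, accumulates, and stores one chunk of `tc_vals_per_iter` rows of the KQ tile, so a warp only ever holds `ncols/frag_n` accumulator fragments at a time. A scalar reference model of the same tiling, without WMMA or the bounds check, is sketched below (host-side, all names illustrative, not part of the patch):

```cuda
// Host-side scalar model of the restructured KQ loop (illustrative only).
// K holds the current block of key vectors (row i = key position, column k = head element),
// Q holds ncols query vectors, and KQ is written column-major with leading dimension ld_KQ,
// mirroring the store_matrix_sync layout in the kernel.
template <int D, int ncols, int tc_vals_per_iter>
void kq_tile_reference(const float * K, const float * Q, float * KQ, int ld_KQ) {
    for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {    // one tensor core pass
        for (int i = i_KQ_0; i < i_KQ_0 + tc_vals_per_iter; ++i) {    // rows produced this pass
            for (int j = 0; j < ncols; ++j) {                         // query columns
                float acc = 0.0f;                                     // stands in for KQ_c
                for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {      // 16-wide slices of the head dim
                    for (int k = k_KQ_0; k < k_KQ_0 + 16; ++k) {
                        acc += K[i*D + k] * Q[j*D + k];
                    }
                }
                KQ[j*ld_KQ + i] = acc;                                // store like store_matrix_sync
            }
        }
    }
}
```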
@@ -7687,31 +7692,40 @@ static __global__ void flash_attn_ext_f16(
             }
         }
 
-        frag_c VKQ_c[ncols/frag_n];
+        frag_c VKQ_c[D/tc_vals_per_iter][ncols/frag_n];
 #pragma unroll
-        for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[j], 0.0f);
-        }
-
-#pragma unroll
-        for (int k0 = 0; k0 < D; k0 += 16) {
-            if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
-                break;
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_KQ_0/tc_vals_per_iter][j], 0.0f);
             }
 
-            frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + frag_m*threadIdx.y, stride_KV);
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[j], v_a, KQ_b[k0/16][j], VKQ_c[j]);
+#pragma unroll
+            for (int k0 = 0; k0 < D; k0 += 16) {
+                if (256 % D != 0 && k_VKQ_0 + k0 >= ne11) {
+                    break;
+                }
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k0)*stride_KV + i_KQ_0 + frag_m*threadIdx.y, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_KQ_0/tc_vals_per_iter][j], v_a, KQ_b[k0/16][j], VKQ_c[i_KQ_0/tc_vals_per_iter][j]);
+                }
             }
         }
 
         __syncthreads();
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(KQ + j0*D_padded + frag_m*threadIdx.y, VKQ_c[j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += tc_vals_per_iter) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + j0*D_padded + i_KQ_0 + frag_m*threadIdx.y,
+                    VKQ_c[i_KQ_0/tc_vals_per_iter][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
         }
 
         __syncthreads();
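
For the V*KQ part the accumulator is widened to one set of fragments per output chunk, `VKQ_c[D/tc_vals_per_iter][ncols/frag_n]`, rather than being moved inside the loop as in the KQ step, since the results are written back into the shared KQ buffer only after the `__syncthreads()` that follows the compute loop. A rough, illustrative count of live accumulator fragments per warp, reusing the D == 256, ncols == 16, frag_n == 16 example (not part of the patch):

```cuda
// Illustrative only: accumulator fragments kept live per warp for the VKQ product.
//   before: VKQ_c[ncols/frag_n]                      -> ncols/frag_n fragments
//   after:  VKQ_c[D/tc_vals_per_iter][ncols/frag_n]  -> (D/tc_vals_per_iter)*(ncols/frag_n) fragments
// For D == 256, ncols == 16, frag_n == 16, tc_vals_per_iter == 128:
static_assert(16/16 == 1,             "before: one VKQ accumulator fragment per warp");
static_assert((256/128)*(16/16) == 2, "after: two VKQ accumulator fragments per warp");
```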
@@ -11453,7 +11467,7 @@ inline void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, const ggml
         cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
-    const int nwarps = Q->ne[0] / frag_m;
+    const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
     const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     const size_t shmem = 0;
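
The host-side `nwarps` mirrors the device-side constant so that `block_dim` matches the block size the kernel was compiled for. A hedged sketch of the resulting launch geometry for one hypothetical case, head size 256 with cols_per_block == 16 (all `_ex` names are illustrative):

```cuda
// Illustrative launch geometry for Q->ne[0] == 256 (head size), cols_per_block == 16:
constexpr int WARP_SIZE_ex      = 32;
constexpr int head_size_ex      = 256;
constexpr int cols_per_block_ex = 16;
constexpr int frag_m_host_ex    = cols_per_block_ex == 8 ? 32 : 16;
constexpr int nwarps_host_ex    = (head_size_ex <= 128 || cols_per_block_ex == 8 ? head_size_ex : head_size_ex/2) / frag_m_host_ex;
static_assert(nwarps_host_ex == 8,                  "8 warps instead of 16 for a 256 head size");
static_assert(nwarps_host_ex * WARP_SIZE_ex == 256, "block_dim gives 256 threads, matching __launch_bounds__ for D > 128");
// With e.g. Q->ne[1] == 32 queries, blocks_num.x = (32 + 16 - 1) / 16 = 2 blocks per head.
```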