
Commit 105332c

metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
1 parent 260cdb2 commit 105332c

File tree

2 files changed: +361 -32 lines changed

ggml-metal.m

Lines changed: 87 additions & 32 deletions
@@ -183,6 +183,8 @@
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -621,12 +623,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,  flash_attn_ext_f16_h64,  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,  flash_attn_ext_f16_h80,  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,  flash_attn_ext_f16_h96,  true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,      flash_attn_ext_f16_h64,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,      flash_attn_ext_f16_h80,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,      flash_attn_ext_f16_h96,      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,     flash_attn_ext_f16_h112,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,     flash_attn_ext_f16_h128,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     flash_attn_ext_f16_h256,     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
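Each GGML_METAL_ADD_KERNEL line above registers one compiled Metal function as a compute pipeline in the backend context. As a rough sketch of what that registration amounts to for one of the new kernels (this is not the actual macro body; `library` and `device` stand in for the context's MTLLibrary and MTLDevice, and the "kernel_" host name is assumed to follow the pattern used by the other flash-attention kernels):

```objc
// Illustrative sketch only -- roughly what registering flash_attn_ext_vec_f16_h128 involves.
NSError * error = nil;

// look up the function compiled into the Metal library (assumed host name)
id<MTLFunction> func = [library newFunctionWithName:@"kernel_flash_attn_ext_vec_f16_h128"];

// build a compute pipeline state that the graph-compute path can later bind and dispatch
id<MTLComputePipelineState> pso = [device newComputePipelineStateWithFunction:func error:&error];
if (pso == nil) {
    GGML_METAL_LOG_ERROR("failed to create pipeline: %s\n", [[error description] UTF8String]);
}
```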
@@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute(

         id<MTLComputePipelineState> pipeline = nil;

-        switch (ne00) {
-            case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
-            case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
-            case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
-            case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
-            case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-            case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
-            default:
-                {
-                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                    GGML_ASSERT(false && "add template specialization for this size");
-                }
+        if (ne01 > 1 || (ne00%128 != 0)) {
+            switch (ne00) {
+                case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                default:
+                    {
+                        GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                        GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                        GGML_ASSERT(false && "add template specialization for this size");
+                    }
+            }
+        } else {
+            switch (ne00) {
+                case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                default:
+                    {
+                        GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                        GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                        GGML_ASSERT(false && "add template specialization for this size");
+                    }
+            }
         }

         // TODO: extend if necessary
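In plain terms, the dispatch above routes every case that is not single-token decoding to the existing half8x8 kernels and reserves the new vec kernels for the BS=1 case with a head size that has a VEC specialization. A minimal restatement of that predicate (an illustrative helper, not part of the patch):

```objc
#include <stdbool.h>
#include <stdint.h>

// Illustrative only -- restates the branch condition above.
// ne00 is the head size, ne01 the number of query rows in the batch.
static bool use_flash_attn_vec_kernel(int64_t ne00, int64_t ne01) {
    // single query row (BS=1 decoding) and a head size with a VEC
    // specialization, i.e. a multiple of 128 (currently 128 or 256)
    return ne01 <= 1 && (ne00 % 128 == 0);
}
```

Everything else (multi-token batches, or head sizes 64/80/96/112) keeps using the half8x8 kernels.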
@@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute(
         [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
         [encoder setBytes:&scale length:sizeof( float) atIndex:27];

-        const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
-        const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+        // half8x8 kernel
+        if (ne01 > 1 || (ne00%128 != 0)) {
+            const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

-        GGML_ASSERT(nqptg <= 32);
-        GGML_ASSERT(nqptg % 8 == 0);
-        GGML_ASSERT(ncpsg % 32 == 0);
+            GGML_ASSERT(nqptg <= 32);
+            GGML_ASSERT(nqptg % 8 == 0);
+            GGML_ASSERT(ncpsg % 32 == 0);

-        // simdgroups per threadgroup (a.k.a. warps)
-        // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
-        const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
+            // simdgroups per threadgroup (a.k.a. warps)
+            // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+            const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;

-        const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+            const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);

-        //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-        GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
-        [encoder setThreadgroupMemoryLength:smem atIndex:0];
+            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+            [encoder setThreadgroupMemoryLength:smem atIndex:0];
+
+            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+        } else {
+            // half1x4 kernel
+            const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
+            const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!

-        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+            GGML_ASSERT(nqptg <= 32);
+            GGML_ASSERT(nqptg % 1 == 0);
+            GGML_ASSERT(ncpsg % 32 == 0);
+
+            // simdgroups per threadgroup (a.k.a. warps)
+            // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+            const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
+
+            int64_t nsg = 1;
+            while (nsg <= nsgt) {
+                nsg *= 2;
+            }
+            nsg /= 2;
+
+            // require power of 2
+            //{
+            //    int64_t nsgm = 1;
+            //    while (nsgm < nsg) {
+            //        nsgm *= 2;
+            //    }
+            //    GGML_ASSERT(nsg == nsgm);
+            //}
+
+            const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
+
+            //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+            GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+            [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
+
+            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+        }
     } break;
 case GGML_OP_DUP:
 case GGML_OP_CPY:
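The half1x4 (vec) branch sizes its launch around the KV cache rather than the query batch: it derives an upper bound on simdgroups from ne11 (the KV sequence length) and the pipeline's thread limit, rounds that down to a power of two for the final parallel reduce, and reserves an extra nsg*ne00 halves of threadgroup memory for the per-simdgroup partial outputs. A minimal sketch of that sizing logic, assuming ggml-style MIN/MAX and passing pipeline.maxTotalThreadsPerThreadgroup in as max_threads (illustrative, not part of the patch):

```objc
#include <stddef.h>
#include <stdint.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// Illustrative sketch only -- mirrors the half1x4 launch sizing in the diff above.
static void fa_vec_launch_params(int64_t ne00, int64_t ne11, int64_t max_threads,
                                 int64_t * nsg_out, size_t * smem_out) {
    const int64_t nqptg = 1;  // queries per threadgroup
    const int64_t ncpsg = 32; // cache values per simdgroup

    // upper bound on simdgroups per threadgroup: enough to cover the KV cache,
    // but never more than the pipeline allows (32 threads per simdgroup)
    const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, max_threads/32));

    // round down to the largest power of two not exceeding the bound
    int64_t nsg = 1;
    while (nsg <= nsgt) {
        nsg *= 2;
    }
    nsg /= 2;

    // fp16 threadgroup memory, following the formula in the diff:
    // the nqptg*(...) area plus nsg*ne00 for per-simdgroup partial results
    const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

    *nsg_out  = nsg;
    *smem_out = smem; // the encoder then pads this to a multiple of 16 bytes (GGML_PAD)
}
```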
