183 | 183 |     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
184 | 184 |     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
185 | 185 |     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
    | 186 | +   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
    | 187 | +   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
186 | 188 |     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
187 | 189 |     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
188 | 190 |     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -621,12 +623,14 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
621 | 623 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
622 | 624 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
623 | 625 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
624 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
625 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
626 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
627 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
628 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
629 |     | -   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
    | 626 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
    | 627 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
    | 628 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
    | 629 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
    | 630 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
    | 631 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
    | 632 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
    | 633 | +   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
630 | 634 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
631 | 635 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
632 | 636 |     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -2563,19 +2567,32 @@ static enum ggml_status ggml_metal_graph_compute(
2563 | 2567 |
2564 | 2568 |             id<MTLComputePipelineState> pipeline = nil;
2565 | 2569 |
2566 |      | -           switch (ne00) {
2567 |      | -               case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
2568 |      | -               case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
2569 |      | -               case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
2570 |      | -               case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
2571 |      | -               case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
2572 |      | -               case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
2573 |      | -               default:
2574 |      | -                   {
2575 |      | -                       GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
2576 |      | -                       GGML_METAL_LOG_ERROR("add template specialization for this size\n");
2577 |      | -                       GGML_ASSERT(false && "add template specialization for this size");
2578 |      | -                   }
     | 2570 | +           if (ne01 > 1 || (ne00%128 != 0)) {
     | 2571 | +               switch (ne00) {
     | 2572 | +                   case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
     | 2573 | +                   case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
     | 2574 | +                   case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
     | 2575 | +                   case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
     | 2576 | +                   case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
     | 2577 | +                   case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
     | 2578 | +                   default:
     | 2579 | +                       {
     | 2580 | +                           GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
     | 2581 | +                           GGML_METAL_LOG_ERROR("add template specialization for this size\n");
     | 2582 | +                           GGML_ASSERT(false && "add template specialization for this size");
     | 2583 | +                       }
     | 2584 | +               }
     | 2585 | +           } else {
     | 2586 | +               switch (ne00) {
     | 2587 | +                   case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
     | 2588 | +                   case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
     | 2589 | +                   default:
     | 2590 | +                       {
     | 2591 | +                           GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
     | 2592 | +                           GGML_METAL_LOG_ERROR("add template specialization for this size\n");
     | 2593 | +                           GGML_ASSERT(false && "add template specialization for this size");
     | 2594 | +                       }
     | 2595 | +               }
2579 | 2596 |             }
2580 | 2597 |
2581 | 2598 |             // TODO: extend if necessary
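Note on the hunk above: the new condition routes single-row (token-by-token) queries with a head size that is a multiple of 128 to the new `*_VEC_*` pipelines, while every other shape keeps the existing half8x8 kernels. A minimal C sketch of that selection criterion, reusing the `ne00`/`ne01` names from the hunk; the standalone helper itself is hypothetical and not part of the patch:

    #include <stdbool.h>
    #include <stdint.h>

    // ne00 = head size, ne01 = number of query rows in the batch
    static bool use_vec_kernel(int64_t ne00, int64_t ne01) {
        // the vec kernel only has specializations for head sizes 128 and 256,
        // and is intended for single-row decoding; otherwise fall back to the
        // existing half8x8 flash-attention kernels
        return !(ne01 > 1 || (ne00 % 128 != 0));
    }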
@@ -2609,24 +2626,62 @@ static enum ggml_status ggml_metal_graph_compute(
2609 | 2626 |             [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2610 | 2627 |             [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2611 | 2628 |
2612 |      | -           const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
2613 |      | -           const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
     | 2629 | +           // half8x8 kernel
     | 2630 | +           if (ne01 > 1 || (ne00%128 != 0)) {
     | 2631 | +               const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
     | 2632 | +               const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2614 | 2633 |
2615 |      | -           GGML_ASSERT(nqptg <= 32);
2616 |      | -           GGML_ASSERT(nqptg % 8 == 0);
2617 |      | -           GGML_ASSERT(ncpsg % 32 == 0);
     | 2634 | +               GGML_ASSERT(nqptg <= 32);
     | 2635 | +               GGML_ASSERT(nqptg % 8 == 0);
     | 2636 | +               GGML_ASSERT(ncpsg % 32 == 0);
2618 | 2637 |
2619 |      | -           // simdgroups per threadgroup (a.k.a. warps)
2620 |      | -           // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
2621 |      | -           const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
     | 2638 | +               // simdgroups per threadgroup (a.k.a. warps)
     | 2639 | +               // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
     | 2640 | +               const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
2622 | 2641 |
2623 |      | -           const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
     | 2642 | +               const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
2624 | 2643 |
2625 |      | -           //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
2626 |      | -           GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
2627 |      | -           [encoder setThreadgroupMemoryLength:smem atIndex:0];
     | 2644 | +               //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
     | 2645 | +               GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
     | 2646 | +               [encoder setThreadgroupMemoryLength:smem atIndex:0];
     | 2647 | +
     | 2648 | +               [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     | 2649 | +           } else {
     | 2650 | +               // half1x4 kernel
     | 2651 | +               const int64_t nqptg = 1;  // queries per threadgroup    !! sync with kernel template arguments !!
     | 2652 | +               const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
2628 | 2653 |
2629 |      | -           [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     | 2654 | +               GGML_ASSERT(nqptg <= 32);
     | 2655 | +               GGML_ASSERT(nqptg % 1 == 0);
     | 2656 | +               GGML_ASSERT(ncpsg % 32 == 0);
     | 2657 | +
     | 2658 | +               // simdgroups per threadgroup (a.k.a. warps)
     | 2659 | +               // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
     | 2660 | +               const int64_t nsgt = MAX(2, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32));
     | 2661 | +
     | 2662 | +               int64_t nsg = 1;
     | 2663 | +               while (nsg <= nsgt) {
     | 2664 | +                   nsg *= 2;
     | 2665 | +               }
     | 2666 | +               nsg /= 2;
     | 2667 | +
     | 2668 | +               // require power of 2
     | 2669 | +               //{
     | 2670 | +               //    int64_t nsgm = 1;
     | 2671 | +               //    while (nsgm < nsg) {
     | 2672 | +               //        nsgm *= 2;
     | 2673 | +               //    }
     | 2674 | +               //    GGML_ASSERT(nsg == nsgm);
     | 2675 | +               //}
     | 2676 | +
     | 2677 | +               const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
     | 2678 | +
     | 2679 | +               //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
     | 2680 | +               GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
     | 2681 | +               [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
     | 2682 | +
     | 2683 | +               [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
     | 2684 | +           }
2630 | 2685 |         } break;
2631 | 2686 |     case GGML_OP_DUP:
2632 | 2687 |     case GGML_OP_CPY:
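Note on the half1x4 branch above: the simdgroup count is first clamped with MAX(2, MIN(...)) and then rounded down to the largest power of two not exceeding that target, and the threadgroup memory gains an extra nsg*ne00 half-precision area that the half8x8 path does not allocate. A small standalone C sketch that reproduces the arithmetic with assumed example sizes (ne00 = 128, ne11 = 4096, a 1024-thread pipeline limit; all values hypothetical, not taken from the patch):

    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const int64_t ne00  = 128;        // head size                  (example value)
        const int64_t ne11  = 4096;       // KV cache length            (example value)
        const int64_t nqptg = 1;          // queries per threadgroup
        const int64_t ncpsg = 32;         // cache values per simdgroup
        const int64_t max_threads = 1024; // stand-in for pipeline.maxTotalThreadsPerThreadgroup

        // nsgt = MAX(2, MIN(ne11/ncpsg, max_threads/32))
        int64_t nsgt = ne11/ncpsg < max_threads/32 ? ne11/ncpsg : max_threads/32;
        if (nsgt < 2) {
            nsgt = 2;
        }

        // round down to the largest power of two <= nsgt
        int64_t nsg = 1;
        while (nsg <= nsgt) {
            nsg *= 2;
        }
        nsg /= 2;

        // elements are fp16, hence sizeof(float)/2; the trailing nsg*ne00 term is
        // the extra per-simdgroup area that the half8x8 path does not allocate
        const size_t smem = (nqptg*(ne00 + nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);

        printf("nsg = %lld, smem = %zu bytes\n", (long long) nsg, smem); // nsg = 32, smem = 10560
        return 0;
    }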