|
82 | 82 | GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
83 | 83 | GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84 | 84 | GGML_METAL_KERNEL_TYPE_NORM,
|
| 85 | + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, |
85 | 86 | GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
86 | 87 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
87 | 88 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -538,6 +539,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
538 | 539 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
539 | 540 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
540 | 541 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 542 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); |
541 | 543 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
542 | 544 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
543 | 545 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -799,6 +801,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
799 | 801 | return false;
|
800 | 802 | }
|
801 | 803 | return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 804 | + case GGML_OP_SSM_CONV: |
| 805 | + return true; |
802 | 806 | case GGML_OP_MUL_MAT:
|
803 | 807 | case GGML_OP_MUL_MAT_ID:
|
804 | 808 | return ctx->support_simdgroup_reduction &&
|
@@ -1531,6 +1535,39 @@ static enum ggml_status ggml_metal_graph_compute(
|
1531 | 1535 | [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1532 | 1536 | }
|
1533 | 1537 | } break;
|
| 1538 | + case GGML_OP_SSM_CONV: |
| 1539 | + { |
| 1540 | + GGML_ASSERT(src0t == GGML_TYPE_F32); |
| 1541 | + GGML_ASSERT(src1t == GGML_TYPE_F32); |
| 1542 | + |
| 1543 | + GGML_ASSERT(ggml_is_contiguous(src0)); |
| 1544 | + GGML_ASSERT(ggml_is_contiguous(src1)); |
| 1545 | + |
| 1546 | + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; |
| 1547 | + |
| 1548 | + [encoder setComputePipelineState:pipeline]; |
| 1549 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 1550 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 1551 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
| 1552 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
| 1553 | + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; |
| 1554 | + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
| 1555 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |
| 1556 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |
| 1557 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |
| 1558 | + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |
| 1559 | + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |
| 1560 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; |
| 1561 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; |
| 1562 | + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; |
| 1563 | + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; |
| 1564 | + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; |
| 1565 | + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; |
| 1566 | + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; |
| 1567 | + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; |
| 1568 | + |
| 1569 | + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; |
| 1570 | + } break; |
1534 | 1571 | case GGML_OP_MUL_MAT:
|
1535 | 1572 | {
|
1536 | 1573 | GGML_ASSERT(ne00 == ne10);
|
|
0 commit comments