|
83 | 83 | GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84 | 84 | GGML_METAL_KERNEL_TYPE_NORM,
|
85 | 85 | GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
| 86 | + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, |
86 | 87 | GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
87 | 88 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
88 | 89 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -544,6 +545,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
544 | 545 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
545 | 546 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
546 | 547 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
|
| 548 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); |
547 | 549 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
548 | 550 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
549 | 551 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -806,6 +808,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
806 | 808 | }
|
807 | 809 | return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
808 | 810 | case GGML_OP_SSM_CONV:
|
| 811 | + case GGML_OP_SSM_SCAN: |
809 | 812 | return true;
|
810 | 813 | case GGML_OP_MUL_MAT:
|
811 | 814 | case GGML_OP_MUL_MAT_ID:
|
@@ -1575,6 +1578,88 @@ static enum ggml_status ggml_metal_graph_compute(
|
1575 | 1578 |
|
1576 | 1579 | [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1577 | 1580 | } break;
|
| 1581 | + case GGML_OP_SSM_SCAN: |
| 1582 | + { |
| 1583 | + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; |
| 1584 | + struct ggml_tensor * src4 = gf->nodes[i]->src[4]; |
| 1585 | + struct ggml_tensor * src5 = gf->nodes[i]->src[5]; |
| 1586 | + |
| 1587 | + GGML_ASSERT(src3); |
| 1588 | + GGML_ASSERT(src4); |
| 1589 | + GGML_ASSERT(src5); |
| 1590 | + |
| 1591 | + size_t offs_src3 = 0; |
| 1592 | + size_t offs_src4 = 0; |
| 1593 | + size_t offs_src5 = 0; |
| 1594 | + |
| 1595 | + id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; |
| 1596 | + id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; |
| 1597 | + id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; |
| 1598 | + |
| 1599 | + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); |
| 1600 | + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); |
| 1601 | + |
| 1602 | + const uint64_t nb30 = src3->nb[0]; |
| 1603 | + const uint64_t nb31 = src3->nb[1]; |
| 1604 | + |
| 1605 | + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); |
| 1606 | + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); |
| 1607 | + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); |
| 1608 | + |
| 1609 | + const uint64_t nb40 = src4->nb[0]; |
| 1610 | + const uint64_t nb41 = src4->nb[1]; |
| 1611 | + const uint64_t nb42 = src4->nb[2]; |
| 1612 | + |
| 1613 | + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); |
| 1614 | + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); |
| 1615 | + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); |
| 1616 | + |
| 1617 | + const uint64_t nb50 = src5->nb[0]; |
| 1618 | + const uint64_t nb51 = src5->nb[1]; |
| 1619 | + const uint64_t nb52 = src5->nb[2]; |
| 1620 | + |
| 1621 | + const int64_t d_state = ne00; |
| 1622 | + const int64_t d_inner = ne01; |
| 1623 | + const int64_t n_seq_tokens = ne11; |
| 1624 | + const int64_t n_seqs = ne02; |
| 1625 | + |
| 1626 | + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; |
| 1627 | + |
| 1628 | + [encoder setComputePipelineState:pipeline]; |
| 1629 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 1630 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 1631 | + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; |
| 1632 | + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; |
| 1633 | + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; |
| 1634 | + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; |
| 1635 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; |
| 1636 | + |
| 1637 | + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; |
| 1638 | + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; |
| 1639 | + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; |
| 1640 | + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; |
| 1641 | + |
| 1642 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; |
| 1643 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; |
| 1644 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; |
| 1645 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; |
| 1646 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; |
| 1647 | + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; |
| 1648 | + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; |
| 1649 | + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; |
| 1650 | + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; |
| 1651 | + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; |
| 1652 | + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; |
| 1653 | + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; |
| 1654 | + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; |
| 1655 | + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; |
| 1656 | + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; |
| 1657 | + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; |
| 1658 | + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; |
| 1659 | + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; |
| 1660 | + |
| 1661 | + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
| 1662 | + } break; |
1578 | 1663 | case GGML_OP_MUL_MAT:
|
1579 | 1664 | {
|
1580 | 1665 | GGML_ASSERT(ne00 == ne10);
|
|
0 commit comments