|
82 | 82 | GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
83 | 83 | GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84 | 84 | GGML_METAL_KERNEL_TYPE_NORM,
|
| 85 | + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, |
| 86 | + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, |
85 | 87 | GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
86 | 88 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
87 | 89 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -542,6 +544,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
542 | 544 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
543 | 545 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
544 | 546 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 547 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); |
| 548 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); |
545 | 549 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
546 | 550 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
547 | 551 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -803,6 +807,9 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
803 | 807 | return false;
|
804 | 808 | }
|
805 | 809 | return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 810 | + case GGML_OP_SSM_CONV: |
| 811 | + case GGML_OP_SSM_SCAN: |
| 812 | + return true; |
806 | 813 | case GGML_OP_MUL_MAT:
|
807 | 814 | case GGML_OP_MUL_MAT_ID:
|
808 | 815 | return ctx->support_simdgroup_reduction &&
|
@@ -1538,6 +1545,121 @@ static enum ggml_status ggml_metal_graph_compute(
|
1538 | 1545 | [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1539 | 1546 | }
|
1540 | 1547 | } break;
|
| 1548 | + case GGML_OP_SSM_CONV: |
| 1549 | + { |
| 1550 | + GGML_ASSERT(src0t == GGML_TYPE_F32); |
| 1551 | + GGML_ASSERT(src1t == GGML_TYPE_F32); |
| 1552 | + |
| 1553 | + GGML_ASSERT(ggml_is_contiguous(src0)); |
| 1554 | + GGML_ASSERT(ggml_is_contiguous(src1)); |
| 1555 | + |
| 1556 | + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; |
| 1557 | + |
| 1558 | + [encoder setComputePipelineState:pipeline]; |
| 1559 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 1560 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 1561 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
| 1562 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
| 1563 | + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; |
| 1564 | + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
| 1565 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |
| 1566 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |
| 1567 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |
| 1568 | + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |
| 1569 | + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |
| 1570 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; |
| 1571 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; |
| 1572 | + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; |
| 1573 | + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; |
| 1574 | + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; |
| 1575 | + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; |
| 1576 | + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; |
| 1577 | + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; |
| 1578 | + |
| 1579 | + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
| 1580 | + } break; |
| 1581 | + case GGML_OP_SSM_SCAN: |
| 1582 | + { |
| 1583 | + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; |
| 1584 | + struct ggml_tensor * src4 = gf->nodes[i]->src[4]; |
| 1585 | + struct ggml_tensor * src5 = gf->nodes[i]->src[5]; |
| 1586 | + |
| 1587 | + GGML_ASSERT(src3); |
| 1588 | + GGML_ASSERT(src4); |
| 1589 | + GGML_ASSERT(src5); |
| 1590 | + |
| 1591 | + size_t offs_src3 = 0; |
| 1592 | + size_t offs_src4 = 0; |
| 1593 | + size_t offs_src5 = 0; |
| 1594 | + |
| 1595 | + id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; |
| 1596 | + id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; |
| 1597 | + id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; |
| 1598 | + |
| 1599 | + const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); |
| 1600 | + const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); |
| 1601 | + |
| 1602 | + const uint64_t nb30 = src3->nb[0]; |
| 1603 | + const uint64_t nb31 = src3->nb[1]; |
| 1604 | + |
| 1605 | + const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); |
| 1606 | + const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); |
| 1607 | + const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); |
| 1608 | + |
| 1609 | + const uint64_t nb40 = src4->nb[0]; |
| 1610 | + const uint64_t nb41 = src4->nb[1]; |
| 1611 | + const uint64_t nb42 = src4->nb[2]; |
| 1612 | + |
| 1613 | + const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); |
| 1614 | + const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); |
| 1615 | + const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); |
| 1616 | + |
| 1617 | + const uint64_t nb50 = src5->nb[0]; |
| 1618 | + const uint64_t nb51 = src5->nb[1]; |
| 1619 | + const uint64_t nb52 = src5->nb[2]; |
| 1620 | + |
| 1621 | + const int64_t d_state = ne00; |
| 1622 | + const int64_t d_inner = ne01; |
| 1623 | + const int64_t n_seq_tokens = ne11; |
| 1624 | + const int64_t n_seqs = ne02; |
| 1625 | + |
| 1626 | + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; |
| 1627 | + |
| 1628 | + [encoder setComputePipelineState:pipeline]; |
| 1629 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 1630 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 1631 | + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; |
| 1632 | + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; |
| 1633 | + [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; |
| 1634 | + [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; |
| 1635 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; |
| 1636 | + |
| 1637 | + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; |
| 1638 | + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; |
| 1639 | + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; |
| 1640 | + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; |
| 1641 | + |
| 1642 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; |
| 1643 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; |
| 1644 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; |
| 1645 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; |
| 1646 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; |
| 1647 | + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; |
| 1648 | + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; |
| 1649 | + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; |
| 1650 | + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; |
| 1651 | + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; |
| 1652 | + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; |
| 1653 | + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; |
| 1654 | + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; |
| 1655 | + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; |
| 1656 | + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; |
| 1657 | + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; |
| 1658 | + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; |
| 1659 | + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; |
| 1660 | + |
| 1661 | + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
| 1662 | + } break; |
1541 | 1663 | case GGML_OP_MUL_MAT:
|
1542 | 1664 | {
|
1543 | 1665 | GGML_ASSERT(ne00 == ne10);
|
|
0 commit comments