Commit 345d590 (parent 7b7db0b)

ggml : add ggml_ssm_conv metal impl

File tree

3 files changed: +115 −0 lines changed

ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal
tests/test-backend-ops.cpp
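For context: GGML_OP_SSM_CONV is the short per-channel convolution used by Mamba-style SSM blocks; this commit adds a Metal kernel mirroring the CPU path in ggml.c (ggml_compute_forward_ssm_conv_f32), registers it, and enables the op in ggml_metal_supports_op. Below is a hedged usage sketch of the op being accelerated — the shapes are inferred from the kernel and tests in this commit, and the d_conv/d_inner names are Mamba convention, not taken from the diff:

#include "ggml.h"

// Build an SSM_CONV node; assumes this era's two-argument ggml_ssm_conv.
ggml_tensor * build_ssm_conv(ggml_context * ctx) {
    // conv input: {d_conv - 1 + n_t, d_inner, n_s} = {8, 1536, 1}
    ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 8, 1536, 1);
    // conv weights: {d_conv, d_inner} = {4, 1536}
    ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 1536);
    // expected result: {d_inner, n_t, n_s} = {1536, 5, 1}
    return ggml_ssm_conv(ctx, sx, c);
}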

ggml/src/ggml-metal.m

Lines changed: 37 additions & 0 deletions
@@ -82,6 +82,7 @@
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -538,6 +539,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,       rms_norm,       ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,     group_norm,     ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,           norm,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,   ssm_conv_f32,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -799,6 +801,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+        case GGML_OP_SSM_CONV:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1531,6 +1535,39 @@ static enum ggml_status ggml_metal_graph_compute(
                        [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                    }
                } break;
+            case GGML_OP_SSM_CONV:
+                {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+                    GGML_ASSERT(ggml_is_contiguous(src1));
+
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:13];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:14];
+                    [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:15];
+                    [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:16];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:17];
+                    [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:18];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                } break;
            case GGML_OP_MUL_MAT:
                {
                    GGML_ASSERT(ne00 == ne10);
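A note on the launch above (a hedged reading of the code, not stated in the commit): the grid is (ne01, ne02, 1) with 128 threads per threadgroup, tgpig.x selects the row, and the kernel strides its time loop by ntg.x. The tgpig.y coordinate is read but then shadowed inside the kernel, so threadgroups along y appear to repeat identical work (each rewrites the same outputs — presumably part of the "// TODO: optimize" note in the kernel). In serial C++ terms the coverage looks roughly like this:

#include <cstdint>

// Illustrative serial rendering of the dispatch; hypothetical, not ggml code.
void ssm_conv_dispatch_shape(int64_t ne01, int64_t ne02, int64_t ne1, int64_t ne2) {
    for (int64_t ir = 0; ir < ne01; ++ir) {         // tgpig.x: one group per row
        for (int64_t y = 0; y < ne02; ++y) {        // tgpig.y: shadowed in the kernel
            for (int64_t tx = 0; tx < 128; ++tx) {  // tpitg.x, with ntg.x == 128
                for (int64_t i3 = 0; i3 < ne2; ++i3) {           // sequences
                    for (int64_t i2 = tx; i2 < ne1; i2 += 128) { // time steps, strided
                        // one dot product per (ir, i2, i3) -- see the kernel below
                    }
                }
            }
        }
    }
}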

ggml/src/ggml-metal.metal

Lines changed: 51 additions & 0 deletions
@@ -667,6 +667,57 @@ kernel void kernel_diag_mask_inf_8(
     }
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
+// TODO: optimize
+kernel void kernel_ssm_conv_f32(
+        device const void * src0,
+        device const void * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+
+    const int64_t nc  = ne10;
+    const int64_t ncs = ne00;
+    const int64_t nr  = ne01;
+    const int64_t n_t = ne1;
+    const int64_t n_s = ne2;
+
+    for (int64_t i3 = 0; i3 < n_s; ++i3) {
+        for (int64_t i2 = tpitg.x; i2 < n_t; i2 += ntg.x) {
+            device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
+            device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
+            device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+
+            float sumf = 0.0f;
+
+            for (int64_t i0 = 0; i0 < nc; ++i0) {
+                sumf += s[i0] * c[i0];
+            }
+
+            *x = sumf;
+        }
+    }
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device       float * dst,
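For reference, here is a plain C++ transcription of what the kernel computes, under the contiguity the host code asserts. It is a minimal sketch — the grounded parts are the loop structure and the ne/nb index math above; the function and parameter names are chosen here for illustration. It can double as a scalar check against the Metal output:

#include <cstdint>

// sx : src0, [n_s][d_inner][nc + n_t - 1] row-major f32 (pre-shifted conv input)
// c  : src1, [d_inner][nc]                per-row conv weights
// y  : dst,  [n_s][n_t][d_inner]          output
static void ssm_conv_ref(const float * sx, const float * c, float * y,
                         int64_t nc, int64_t n_t, int64_t n_s, int64_t d_inner) {
    const int64_t ncs = nc + n_t - 1; // columns per row of sx

    for (int64_t i3 = 0; i3 < n_s; ++i3) {             // sequence
        for (int64_t i2 = 0; i2 < n_t; ++i2) {         // time step
            for (int64_t ir = 0; ir < d_inner; ++ir) { // channel row
                const float * s = sx + (i3*d_inner + ir)*ncs + i2; // sliding window
                const float * w = c  + ir*nc;

                float sumf = 0.0f;
                for (int64_t i0 = 0; i0 < nc; ++i0) {
                    sumf += s[i0] * w[i0]; // dot(window, weights)
                }

                y[(i3*n_t + i2)*d_inner + ir] = sumf;
            }
        }
    }
}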

tests/test-backend-ops.cpp

Lines changed: 27 additions & 0 deletions
@@ -934,6 +934,29 @@ struct test_rms_norm : public test_case {
     }
 };
 
+// GGML_OP_SSM_CONV
+struct test_ssm_conv : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    const std::array<int64_t, 4> ne_b;
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne_a, ne_b);
+    }
+
+    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
+            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
+        : type(type), ne_a(ne_a), ne_b(ne_b) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
+        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2201,6 +2224,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
+    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
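The three cases exercise a single time step, a longer window (ne00 = 8, conv width 4), and a batched variant (ne02 = 4). Assuming the harness's usual op filter applies (an assumption — check the file's usage text), something like "test-backend-ops -o SSM_CONV" should run only these against the Metal backend. The implied output shapes, using the windowing relation n_t = ne00 - ne10 + 1 inferred from the kernel (not stated in the commit):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ncs[3] = {4, 8, 4}; // src0 ne00 per test case
    const int64_t n_s[3] = {1, 1, 4}; // src0 ne02 per test case
    const int64_t nc     = 4;         // src1 ne10 (conv width) in all cases

    for (int i = 0; i < 3; ++i) {
        // expected dst: {d_inner, n_t, n_s}
        std::printf("case %d: dst = {1536, %lld, %lld}\n",
                    i, (long long) (ncs[i] - nc + 1), (long long) n_s[i]);
    }
    return 0;
}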
