Commit b77cb69

Update CUDA ops and tests to match implementation from commit 8fb57ac (llama : use im2col and mul_mat to perform convolution for Mamba); GPU version breaks with assert because of unsupported MUL_MAT
1 parent: 065ef82

File tree: 3 files changed, +56 −92 lines changed

  ggml-cuda/ssm_conv.cu
  ggml-cuda/ssm_scan.cu
  tests/test-backend-ops.cpp

ggml-cuda/ssm_conv.cu

Lines changed: 33 additions & 72 deletions
@@ -2,13 +2,12 @@

 template <int block_size>
 static __global__ void ssm_conv_f32(
-        const float * src0, const float * src1, const float * src2,
-        const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
-        const int src2_nb1,
+        const float * src0, const float * src1,
+        const int src0_nb0, const int src0_nb1, const int src0_nb2,
+        const int src1_nb1,
         float * dst,
         const int dst_nb0, const int dst_nb1, const int dst_nb2,
-        const int nc, const int nr, const int n_t, const int n_s) {
+        const int nc, const int ncs, const int nr, const int n_t, const int n_s) {

     // const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
@@ -24,118 +23,80 @@ static __global__ void ssm_conv_f32(
     const int ir1 = min(ir0 + dr, nr);
     const int ir  = ir1 - ir0;

-    // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
-    // This would avoid having to copy into an intermediate buffer, but the state would be bigger.
-
-    // float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
-    extern __shared__ float wdata_f32[]; // work buffer for all threads
-    float * s = (float *) wdata_f32 + nc*dr*ith;
-
     for (int i3 = 0; i3 < n_s; ++i3) {
-        float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_conv, d_inner, n_s}
-
-        // copy the state into working memory
-        // can't use memcpy because (d_conv) != (d_conv - 1)
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-            }
-        }
-
         for (int i2 = 0; i2 < n_t; ++i2) {
-            float * x  = (float *) ((char *) dst  + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
-            float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
-            float * c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
-
-            // shift state left
-            //memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
-            for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
-                s[i4] = s[i4+1];
-            }
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}

+            // TODO: transpose the output for smaller strides for big batches?
             // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // insert x on the last column
-                s[(nc - 1) + i1*nc] = x0[i1];
-            }
-
-            // it seems a little faster when this is separate from the state shift
             for (int i1 = 0; i1 < ir; ++i1) {
                 // rowwise dot product
                 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
                 float sumf = 0.0f;
+
+                // d_conv
                 for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    sumf += s[i] * c[i];
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
                 }
                 x[i1] = sumf;
             }
         }
-
-        // copy the state out of it
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
-            }
-        }
     }
 }

 static void ssm_conv_f32_cuda(
-        const float * src0, const float * src1, const float * src2,
-        const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
-        const int src2_nb1,
+        const float * src0, const float * src1,
+        const int src0_nb0, const int src0_nb1, const int src0_nb2,
+        const int src1_nb1,
         float * dst,
         const int dst_nb0, const int dst_nb1, const int dst_nb2,
-        const int nc, const int nr, const int n_t, const int n_s,
+        const int nc, const int ncs, const int nr, const int n_t, const int n_s,
         cudaStream_t stream) {

     const dim3 block_dims(WARP_SIZE, 1, 1);
     const int nblocks = 1; // TODO
-    const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO

-    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
-        src0, src1, src2,
-        src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1, src1_nb2,
-        src2_nb1,
+    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
+        src0, src1,
+        src0_nb0, src0_nb1, src0_nb2,
+        src1_nb1,
         dst,
         dst_nb0, dst_nb1, dst_nb2,
-        nc, nr, n_t, n_s);
+        nc, ncs, nr, n_t, n_s);
 }

 void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // conv_state
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
+    const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+    const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight

-    const int nc  = src2->ne[0]; // d_conv
+    const int nc  = src1->ne[0]; // d_conv
+    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
     const int nr  = src0->ne[1]; // d_inner
-    const int n_t = src1->ne[1]; // tokens per sequence
-    const int n_s = src0->ne[2]; // number of sequences in the batch
+    const int n_t = dst->ne[1]; // tokens per sequence
+    const int n_s = dst->ne[2]; // number of sequences in the batch

-    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+    GGML_ASSERT( dst->ne[0] == nr);
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
-    const float * src2_d = (const float *)src2->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
-                      src0->nb[1], src0->nb[2],
-                      src1->nb[0], src1->nb[1], src1->nb[2],
-                      src2->nb[1],
+    ssm_conv_f32_cuda(src0_d, src1_d,
+                      src0->nb[0], src0->nb[1], src0->nb[2],
+                      src1->nb[1],
                       dst_d,
                       dst->nb[0], dst->nb[1], dst->nb[2],
-                      nc, nr, n_t, n_s,
+                      nc, ncs, nr, n_t, n_s,
                       stream);
 }
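
Note: the kernel no longer stages a shifted copy of the conv state in shared memory; it reads a sliding window straight out of src0, which now holds the previous state and the new tokens concatenated along dim 0. A minimal host-side C++ restatement of the same math, assuming contiguous f32 buffers (the helper name and layout comments are illustrative, not part of the commit):

// Reference restatement of the new ssm_conv (sketch, contiguous f32 assumed).
// sx: {d_conv - 1 + n_t, d_inner, n_s}  conv states + new tokens, concatenated
// c : {d_conv, d_inner}                 conv1d.weight
// x : {d_inner, n_t, n_s}               output
static void ssm_conv_ref(const float * sx, const float * c, float * x,
                         int nc, int nr, int n_t, int n_s) {
    const int ncs = nc - 1 + n_t; // row stride of sx, in elements
    for (int i3 = 0; i3 < n_s; ++i3) {         // sequences
        for (int i2 = 0; i2 < n_t; ++i2) {     // tokens
            for (int i1 = 0; i1 < nr; ++i1) {  // d_inner rows
                float sumf = 0.0f;
                // dot product over a window of nc elements starting at token i2
                for (int i0 = 0; i0 < nc; ++i0) {
                    sumf += sx[(i3*nr + i1)*ncs + i2 + i0] * c[i1*nc + i0];
                }
                x[(i3*n_t + i2)*nr + i1] = sumf;
            }
        }
    }
}

This is the shape contract the changed asserts encode: src0->ne[0] must equal ncs = d_conv - 1 + n_t, and the window advances by one element per token via the new src0_nb0 stride.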

ggml-cuda/ssm_scan.cu

Lines changed: 21 additions & 17 deletions
@@ -5,13 +5,12 @@ static __global__ void ssm_scan_f32(
         const float * src0, const float * src1, const float * src2, const float * src3,
         const float * src4, const float * src5,
         const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
         const int src2_nb0, const int src2_nb1, const int src2_nb2,
         const int src3_nb1,
         const int src4_nb1, const int src4_nb2,
         const int src5_nb1, const int src5_nb2,
         float * dst,
-        const int dst_nb0, const int dst_nb1, const int dst_nb2,
         const int nc, const int nr, const int n_t, const int n_s) {

     // const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -30,13 +29,17 @@ static __global__ void ssm_scan_f32(

     for (int i3 = 0; i3 < n_s; ++i3) {
         for (int i2 = 0; i2 < n_t; ++i2) {
-            float * y  = (float *) ((char *) dst  + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
-            float * s  = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
-            float * x  = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
-            float * dt = (float *) ((char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
-            float * A  = (float *) ((char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
-            float * B  = (float *) ((char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
-            float * C  = (float *) ((char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
+            const float * s0 = (const float *) ((const char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_state, d_inner, n_s}
+            const float * x  = (const float *) ((const char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            const float * dt = (const float *) ((const char *) src2 + ir0*src2_nb0 + i2*src2_nb1 + i3*src2_nb2); // {d_inner, n_t, n_s}
+            const float * A  = (const float *) ((const char *) src3 + ir0*src3_nb1); // {d_state, d_inner}
+            const float * B  = (const float *) ((const char *) src4 + i2*src4_nb1 + i3*src4_nb2); // {d_state, n_t, n_s}
+            const float * C  = (const float *) ((const char *) src5 + i2*src5_nb1 + i3*src5_nb2); // {d_state, n_t, n_s}
+            float * y = (float *) ((char *) dst + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst + ir0*src0_nb1 + i3*src0_nb2 + src1_nb3);    // {d_state, d_inner, n_s}
+
+            // use the output as the source for the next token-wise iterations
+            if (i2 > 0) { s0 = s; }

             // d_inner
             for (int i1 = 0; i1 < ir; ++i1) {
@@ -48,7 +51,7 @@ static __global__ void ssm_scan_f32(
                 for (int i0 = 0; i0 < nc; ++i0) {
                     int i = i0 + i1*nc;
                     // state = prev_state * dA + dB * x
-                    float state = (s[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
                     // y = rowwise_dotprod(state, C)
                     sumf += state * C[i0];
                     s[i] = state;
@@ -63,13 +66,12 @@ static void ssm_scan_f32_cuda(
         const float * src0, const float * src1, const float * src2, const float * src3,
         const float * src4, const float * src5,
         const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
         const int src2_nb0, const int src2_nb1, const int src2_nb2,
         const int src3_nb1,
         const int src4_nb1, const int src4_nb2,
         const int src5_nb1, const int src5_nb2,
         float * dst,
-        const int dst_nb0, const int dst_nb1, const int dst_nb2,
         const int nc, const int nr, const int n_t, const int n_s,
         cudaStream_t stream) {

@@ -80,13 +82,12 @@ static void ssm_scan_f32_cuda(
         src0, src1, src2, src3,
         src4, src5,
         src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1, src1_nb2,
+        src1_nb0, src1_nb1, src1_nb2, src1_nb3,
         src2_nb0, src2_nb1, src2_nb2,
         src3_nb1,
         src4_nb1, src4_nb2,
         src5_nb1, src5_nb2,
         dst,
-        dst_nb0, dst_nb1, dst_nb2,
         nc, nr, n_t, n_s);
 }

@@ -103,7 +104,7 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t n_t = src1->ne[1]; // number of tokens per sequence
     const int64_t n_s = src0->ne[2]; // number of sequences in the batch

-    GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
@@ -112,6 +113,10 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src5->nb[0] == sizeof(float));
     // required for the dot product between s and C
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
+    // required for per-sequence offsets for states
+    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
+    // required to get correct offset for state destination (i.e. src1->nb[3])
+    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));

     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
@@ -129,13 +134,12 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         src0_d, src1_d, src2_d, src3_d,
         src4_d, src5_d,
         src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1], src1->nb[2],
+        src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
         src2->nb[0], src2->nb[1], src2->nb[2],
         src3->nb[1],
         src4->nb[1], src4->nb[2],
         src5->nb[1], src5->nb[2],
         dst_d,
-        dst->nb[0], dst->nb[1], dst->nb[2],
         nc, nr, n_t, n_s,
         stream);
 }
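
Note on the removed dst_nb* parameters: dst is no longer just y. It now packs the per-token outputs {d_inner, n_t, n_s} followed by the updated recurrent states {d_state, d_inner, n_s}, which is why the element-count assert became ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst), and why iterations with i2 > 0 read the previous state back from dst instead of src0. A small sketch of the resulting offsets, assuming contiguous f32 tensors (the helper is illustrative, not from the commit):

// Locating the two parts of ssm_scan's packed dst (sketch).
// With contiguous f32 inputs, src1->nb[3] == nr*n_t*n_s*sizeof(float),
// so the updated states start right past y.
static void ssm_scan_split_dst(float * dst_d, int64_t nr, int64_t n_t, int64_t n_s,
                               float ** y_out, float ** s_out) {
    *y_out = dst_d;               // y: {d_inner, n_t, n_s}
    *s_out = dst_d + nr*n_t*n_s;  // s: {d_state, d_inner, n_s}
}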

tests/test-backend-ops.cpp

Lines changed: 2 additions & 3 deletions
@@ -1579,10 +1579,9 @@ struct test_ssm_conv : public test_case {
         : type(type), d_conv(d_conv), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s = ggml_new_tensor_3d(ctx, type, d_conv - 1, d_inner, n_seqs);
-        ggml_tensor * x = ggml_new_tensor_3d(ctx, type, d_inner, n_seq_tokens, n_seqs);
+        ggml_tensor * sx = ggml_new_tensor_3d(ctx, type, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
         ggml_tensor * c = ggml_new_tensor_2d(ctx, type, d_conv, d_inner);
-        ggml_tensor * out = ggml_ssm_conv(ctx, s, x, c);
+        ggml_tensor * out = ggml_ssm_conv(ctx, sx, c);
         return out;
     }
 };
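
For context, ggml_ssm_conv now takes two sources instead of three; a hedged usage sketch mirroring the updated test (dimension names as in the test, output shape per the asserts in ggml_cuda_op_ssm_conv):

// the previous conv state (d_conv - 1 columns) and the n_seq_tokens new
// columns are fused into one input along dim 0 before building the op
ggml_tensor * sx  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_seq_tokens, d_inner, n_seqs);
ggml_tensor * c   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
ggml_tensor * out = ggml_ssm_conv(ctx, sx, c); // -> {d_inner, n_seq_tokens, n_seqs}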
