template <int block_size>
static __global__ void ssm_conv_f32(
-        const float * src0, const float * src1, const float * src2,
-        const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
-        const int src2_nb1,
+        const float * src0, const float * src1,
+        const int src0_nb0, const int src0_nb1, const int src0_nb2,
+        const int src1_nb1,
         float * dst,
         const int dst_nb0, const int dst_nb1, const int dst_nb2,
-        const int nc, const int nr, const int n_t, const int n_s) {
+        const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
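    // nc  = d_conv (conv1d kernel width), ncs = d_conv - 1 + n_t (padded src0 row length)
    // nr  = d_inner (rows, split across the threads of the block)
    // n_t = tokens per sequence, n_s = number of sequences in the batch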

    // const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
@@ -24,118 +23,80 @@ static __global__ void ssm_conv_f32(
    const int ir1 = min(ir0 + dr, nr);
    const int ir  = ir1 - ir0;

-    // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
-    // This would avoid having to copy into an intermediate buffer, but the state would be bigger.
-
-    // float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
-    extern __shared__ float wdata_f32[]; // work buffer for all threads
-    float * s = (float *) wdata_f32 + nc*dr*ith;
-
    for (int i3 = 0; i3 < n_s; ++i3) {
-        float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_conv, d_inner, n_s}
-
-        // copy the state into working memory
-        // can't use memcpy because (d_conv) != (d_conv - 1)
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-            }
-        }
-
        for (int i2 = 0; i2 < n_t; ++i2) {
-            float * x  = (float *) ((char *) dst  + ir0* dst_nb0 + i2* dst_nb1 + i3* dst_nb2); // {d_inner, n_t, n_s}
-            float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
-            float * c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
-
-            // shift state left
-            // memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
-            for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
-                s[i4] = s[i4+1];
-            }
+            // {d_conv - 1 + n_t, d_inner, n_seqs}
+            // sliding window
+            const float * s = (const float *) ((const char *) src0 + ir0*src0_nb1 + i2*src0_nb0 + i3*src0_nb2); // {d_conv, d_inner, n_s}
+            const float * c = (const float *) ((const char *) src1 + ir0*src1_nb1); // {d_conv, d_inner}
+            float * x = (float *) ((char *) dst + ir0*dst_nb0 + i2*dst_nb1 + i3*dst_nb2); // {d_inner, n_t, n_s}
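            // for each token i2, s points at a window of nc consecutive elements inside the
            // padded row of ncs = d_conv - 1 + n_t elements, so consecutive tokens reuse
            // overlapping input and no per-thread state buffer or shifting is needed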

+            // TODO: transpose the output for smaller strides for big batches?
            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // insert x on the last column
-                s[(nc - 1) + i1*nc] = x0[i1];
-            }
-
-            // it seems a little faster when this is separate from the state shift
            for (int i1 = 0; i1 < ir; ++i1) {
                // rowwise dot product
                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
                float sumf = 0.0f;
+
+                // d_conv
                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    sumf += s[i] * c[i];
+                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
                }
                x[i1] = sumf;
            }
        }
-
-        // copy the state out of it
-        for (int i1 = 0; i1 < ir; ++i1) {
-            for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
-            }
-        }
    }
}
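
// Reference sketch (illustrative only, not part of the kernel above): the same computation as
// plain host loops, assuming fully contiguous tensors so the byte strides used in the kernel
// reduce to the element strides below. The name ssm_conv_ref and its layout are hypothetical.
static void ssm_conv_ref(
        const float * s,  // {d_conv - 1 + n_t, d_inner, n_s} padded input rows
        const float * c,  // {d_conv, d_inner}                conv1d weights
        float       * x,  // {d_inner, n_t, n_s}              output
        const int nc, const int ncs, const int nr, const int n_t, const int n_s) {
    for (int i3 = 0; i3 < n_s; ++i3) {              // sequences
        for (int i2 = 0; i2 < n_t; ++i2) {          // tokens
            for (int i1 = 0; i1 < nr; ++i1) {       // inner channels
                float sumf = 0.0f;
                for (int i0 = 0; i0 < nc; ++i0) {   // conv taps over the window starting at i2
                    sumf += s[(i3*nr + i1)*ncs + i2 + i0] * c[i1*nc + i0];
                }
                x[(i3*n_t + i2)*nr + i1] = sumf;
            }
        }
    }
}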

static void ssm_conv_f32_cuda(
-        const float * src0, const float * src1, const float * src2,
-        const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1, const int src1_nb2,
-        const int src2_nb1,
+        const float * src0, const float * src1,
+        const int src0_nb0, const int src0_nb1, const int src0_nb2,
+        const int src1_nb1,
         float * dst,
         const int dst_nb0, const int dst_nb1, const int dst_nb2,
-        const int nc, const int nr, const int n_t, const int n_s,
+        const int nc, const int ncs, const int nr, const int n_t, const int n_s,
        cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const int nblocks = 1; // TODO
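    // a single block of WARP_SIZE threads; the kernel splits the nr rows among them via
    // dr/ir0/ir1, and the sliding-window variant needs no dynamic shared memory for state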
-    const int shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO

-    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
-        src0, src1, src2,
-        src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1, src1_nb2,
-        src2_nb1,
+    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
+        src0, src1,
+        src0_nb0, src0_nb1, src0_nb2,
+        src1_nb1,
        dst,
        dst_nb0, dst_nb1, dst_nb2,
-        nc, nr, n_t, n_s);
+        nc, ncs, nr, n_t, n_s);
}
107
71
108
72
void ggml_cuda_op_ssm_conv (ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
109
- const struct ggml_tensor * src0 = dst->src [0 ]; // conv_state
110
- const struct ggml_tensor * src1 = dst->src [1 ]; // x
111
- const struct ggml_tensor * src2 = dst->src [2 ]; // conv1d.weight
73
+ const struct ggml_tensor * src0 = dst->src [0 ]; // conv_x
74
+ const struct ggml_tensor * src1 = dst->src [1 ]; // conv1d.weight
112
75
113
- const int nc = src2->ne [0 ]; // d_conv
76
+ const int nc = src1->ne [0 ]; // d_conv
77
+ const int ncs = src0->ne [0 ]; // d_conv - 1 + n_t
114
78
const int nr = src0->ne [1 ]; // d_inner
115
- const int n_t = src1 ->ne [1 ]; // tokens per sequence
116
- const int n_s = src0 ->ne [2 ]; // number of sequences in the batch
79
+ const int n_t = dst ->ne [1 ]; // tokens per sequence
80
+ const int n_s = dst ->ne [2 ]; // number of sequences in the batch
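    // e.g. (illustrative numbers) with d_conv = 4, d_inner = 1536, 32 tokens and 2 sequences:
    // src0 is {35, 1536, 2}, src1 is {4, 1536} and dst is {1536, 32, 2}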

-    GGML_ASSERT(ggml_are_same_shape(src1, dst));
+    GGML_ASSERT(dst->ne[0] == nr);
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
-    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;
-    const float * src2_d = (const float *)src2->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
-        src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1], src1->nb[2],
-        src2->nb[1],
+    ssm_conv_f32_cuda(src0_d, src1_d,
+        src0->nb[0], src0->nb[1], src0->nb[2],
+        src1->nb[1],
        dst_d,
        dst->nb[0], dst->nb[1], dst->nb[2],
-        nc, nr, n_t, n_s,
+        nc, ncs, nr, n_t, n_s,
        stream);
}