@@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
94
94
}
95
95
96
96
// non-contiguous kernel (slow)
97
- static __global__ void concat_f32_non_cont (
97
+ template <int dim>
98
+ static __global__ void __launch_bounds__ (CUDA_CONCAT_BLOCK_SIZE)
99
+ concat_f32_non_cont(
98
100
const char * src0,
99
101
const char * src1,
100
102
char * dst,
@@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont(
121
123
uint64_t nb0,
122
124
uint64_t nb1,
123
125
uint64_t nb2,
124
- uint64_t nb3,
125
- int32_t dim) {
126
+ uint64_t nb3){
127
+ static_assert (dim >= 0 && dim <= 3 );
128
+
126
129
const int64_t i3 = blockIdx .z ;
127
130
const int64_t i2 = blockIdx .y ;
128
131
const int64_t i1 = blockIdx .x ;
129
132
130
- int64_t o[4 ] = {0 , 0 , 0 , 0 };
131
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
132
-
133
133
const float * x;
134
134
135
- for (int i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
135
+ for (int64_t i0 = threadIdx .x ; i0 < ne0; i0 += blockDim .x ) {
136
136
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
137
137
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
138
138
} else {
139
- x = (const float *)(src1 + (i3 - o[3 ])*nb13 + (i2 - o[2 ])*nb12 + (i1 - o[1 ])*nb11 + (i0 - o[0 ])*nb10);
139
+ if constexpr (dim == 0 ) {
140
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
141
+ } else if constexpr (dim == 1 ) {
142
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
143
+ } else if constexpr (dim == 2 ) {
144
+ x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
145
+ } else if constexpr (dim == 3 ) {
146
+ x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
147
+ }
140
148
}
141
149
142
150
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
182
190
}
183
191
} else {
184
192
dim3 grid_dim (dst->ne [1 ], dst->ne [2 ], dst->ne [3 ]);
185
- concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
186
- (const char *)src0->data ,
187
- (const char *)src1->data ,
188
- ( char *)dst->data ,
193
+ auto launch_kernel = [&](auto dim) {
194
+ concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0 , stream>>> (
195
+ (const char *) src0->data , (const char *) src1->data , (char *) dst->data ,
189
196
src0->ne [0 ], src0->ne [1 ], src0->ne [2 ], src0->ne [3 ],
190
197
src0->nb [0 ], src0->nb [1 ], src0->nb [2 ], src0->nb [3 ],
191
198
src1->ne [0 ], src1->ne [1 ], src1->ne [2 ], src1->ne [3 ],
192
199
src1->nb [0 ], src1->nb [1 ], src1->nb [2 ], src1->nb [3 ],
193
- dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
194
- dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ], dim);
200
+ dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ],
201
+ dst->nb [0 ], dst->nb [1 ], dst->nb [2 ], dst->nb [3 ]);
202
+ };
203
+ switch (dim) {
204
+ case 0 :
205
+ launch_kernel (std::integral_constant<int , 0 >{});
206
+ break ;
207
+ case 1 :
208
+ launch_kernel (std::integral_constant<int , 1 >{});
209
+ break ;
210
+ case 2 :
211
+ launch_kernel (std::integral_constant<int , 2 >{});
212
+ break ;
213
+ case 3 :
214
+ launch_kernel (std::integral_constant<int , 3 >{});
215
+ break ;
216
+ default :
217
+ GGML_ABORT (" Invalid dim: %d" , dim);
218
+ break ;
219
+ }
195
220
}
196
221
}
0 commit comments