Skip to content

Commit 20892a1

Browse files
A3shTnTslaren
authored andcommitted
CUDA: faster non-contiguous concat (ggml-org#10760)
* faster uncontiguous concat * Use a lambda to avoid code duplication Co-authored-by: Diego Devesa <[email protected]> * Update ggml/src/ggml-cuda/concat.cu * add constexpr and static assert --------- Co-authored-by: Diego Devesa <[email protected]>
1 parent 1f50972 commit 20892a1

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

ggml/src/ggml-cuda/concat.cu

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ static void concat_f32_cuda(const float * x, const float * y, float * dst, int n
9494
}
9595

9696
// non-contiguous kernel (slow)
97-
static __global__ void concat_f32_non_cont(
97+
template <int dim>
98+
static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
99+
concat_f32_non_cont(
98100
const char * src0,
99101
const char * src1,
100102
char * dst,
@@ -121,22 +123,28 @@ static __global__ void concat_f32_non_cont(
121123
uint64_t nb0,
122124
uint64_t nb1,
123125
uint64_t nb2,
124-
uint64_t nb3,
125-
int32_t dim) {
126+
uint64_t nb3){
127+
static_assert(dim >= 0 && dim <= 3);
128+
126129
const int64_t i3 = blockIdx.z;
127130
const int64_t i2 = blockIdx.y;
128131
const int64_t i1 = blockIdx.x;
129132

130-
int64_t o[4] = {0, 0, 0, 0};
131-
o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
132-
133133
const float * x;
134134

135-
for (int i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
135+
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
136136
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
137137
x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
138138
} else {
139-
x = (const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
139+
if constexpr (dim == 0) {
140+
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
141+
} else if constexpr (dim == 1) {
142+
x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
143+
} else if constexpr (dim == 2) {
144+
x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
145+
} else if constexpr (dim == 3) {
146+
x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
147+
}
140148
}
141149

142150
float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -182,15 +190,32 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
182190
}
183191
} else {
184192
dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
185-
concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
186-
(const char *)src0->data,
187-
(const char *)src1->data,
188-
( char *)dst->data,
193+
auto launch_kernel = [&](auto dim) {
194+
concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
195+
(const char *) src0->data, (const char *) src1->data, (char *) dst->data,
189196
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
190197
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
191198
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
192199
src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
193-
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
194-
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim);
200+
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
201+
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
202+
};
203+
switch (dim) {
204+
case 0:
205+
launch_kernel(std::integral_constant<int, 0>{});
206+
break;
207+
case 1:
208+
launch_kernel(std::integral_constant<int, 1>{});
209+
break;
210+
case 2:
211+
launch_kernel(std::integral_constant<int, 2>{});
212+
break;
213+
case 3:
214+
launch_kernel(std::integral_constant<int, 3>{});
215+
break;
216+
default:
217+
GGML_ABORT("Invalid dim: %d", dim);
218+
break;
219+
}
195220
}
196221
}

0 commit comments

Comments
 (0)