Skip to content

Commit 62b95f9

Browse files
committed
cuda : support non-contiguous src1 in get_rows
1 parent 2e4db48 commit 62b95f9

File tree

3 files changed: +142 -79 lines changed

ggml-cuda.cu

Lines changed: 88 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1686,31 +1686,39 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
16861686
}
16871687

16881688
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
1689-
static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
1690-
const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
1691-
const int row = blockDim.y*blockIdx.y + threadIdx.y;
1692-
1693-
if (col >= ncols) {
1689+
static __global__ void k_get_rows(
1690+
const void * src0, const int32_t * src1, dst_t * dst,
1691+
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
1692+
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
1693+
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
1694+
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
1695+
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
1696+
1697+
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
1698+
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
1699+
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
1700+
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
1701+
1702+
if (i00 >= ne00) {
16941703
return;
16951704
}
16961705

1697-
const int r = y[row];
1706+
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
16981707

1699-
// copy x[r*ncols + col] to dst[row*ncols + col]
1700-
const int xi = r*ncols + col;
1701-
const int di = row*ncols + col;
1708+
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
1709+
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
17021710

1703-
const int ib = xi/qk; // block index
1704-
const int iqs = (xi%qk)/qr; // quant index
1705-
const int iybs = di - di%qk; // y block start index
1711+
const int ib = i00/qk; // block index
1712+
const int iqs = (i00%qk)/qr; // quant index
1713+
const int iybs = i00 - i00%qk; // dst block start index
17061714
const int y_offset = qr == 1 ? 1 : qk/2;
17071715

17081716
// dequantize
17091717
dfloat2 v;
1710-
dequantize_kernel(x, ib, iqs, v);
1718+
dequantize_kernel(src0_row, ib, iqs, v);
17111719

1712-
dst[iybs + iqs + 0] = v.x;
1713-
dst[iybs + iqs + y_offset] = v.y;
1720+
dst_row[iybs + iqs + 0] = v.x;
1721+
dst_row[iybs + iqs + y_offset] = v.y;
17141722
}
17151723

17161724
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5055,11 +5063,35 @@ static __global__ void im2col_f32_f16(
50555063
}
50565064

50575065
template<int qk, int qr, dequantize_kernel_t dq>
5058-
static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
5066+
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
5067+
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
5068+
5069+
GGML_TENSOR_BINARY_OP_LOCALS
5070+
50595071
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
5060-
const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
5061-
const dim3 block_nums(block_num_x, nrows, 1);
5062-
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
5072+
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
5073+
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
5074+
5075+
// strides in elements
5076+
//const size_t s0 = nb0 / ggml_element_size(dst);
5077+
const size_t s1 = nb1 / ggml_element_size(dst);
5078+
const size_t s2 = nb2 / ggml_element_size(dst);
5079+
const size_t s3 = nb3 / ggml_element_size(dst);
5080+
5081+
const size_t s10 = nb10 / ggml_element_size(src1);
5082+
const size_t s11 = nb11 / ggml_element_size(src1);
5083+
const size_t s12 = nb12 / ggml_element_size(src1);
5084+
//const size_t s13 = nb13 / ggml_element_size(src1);
5085+
5086+
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
5087+
src0_dd, src1_dd, dst_dd,
5088+
ne00, /*ne01, ne02, ne03,*/
5089+
/*ne10, ne11,*/ ne12, /*ne13,*/
5090+
/* s0,*/ s1, s2, s3,
5091+
/* nb00,*/ nb01, nb02, nb03,
5092+
s10, s11, s12/*, s13*/);
5093+
5094+
(void) dst;
50635095
}
50645096

50655097
template<float (*bin_op)(const float, const float)>
@@ -5071,7 +5103,6 @@ struct bin_bcast_cuda {
50715103

50725104
GGML_TENSOR_BINARY_OP_LOCALS
50735105

5074-
50755106
int nr0 = ne10/ne0;
50765107
int nr1 = ne11/ne1;
50775108
int nr2 = ne12/ne2;
@@ -5119,26 +5150,28 @@ struct bin_bcast_cuda {
51195150
int64_t ne12 = cne1[2];
51205151
int64_t ne13 = cne1[3];
51215152

5122-
//size_t nb0 = cnb0[0];
5153+
size_t nb0 = cnb0[0];
51235154
size_t nb1 = cnb0[1];
51245155
size_t nb2 = cnb0[2];
51255156
size_t nb3 = cnb0[3];
51265157

5127-
//size_t nb10 = cnb1[0];
5158+
size_t nb10 = cnb1[0];
51285159
size_t nb11 = cnb1[1];
51295160
size_t nb12 = cnb1[2];
51305161
size_t nb13 = cnb1[3];
51315162

5132-
//size_t s0 = nb0 / sizeof(src1_t);
5163+
size_t s0 = nb0 / sizeof(src1_t);
51335164
size_t s1 = nb1 / sizeof(src1_t);
51345165
size_t s2 = nb2 / sizeof(src1_t);
51355166
size_t s3 = nb3 / sizeof(src1_t);
51365167

5137-
//size_t s10 = nb10 / sizeof(src1_t);
5168+
size_t s10 = nb10 / sizeof(src1_t);
51385169
size_t s11 = nb11 / sizeof(src1_t);
51395170
size_t s12 = nb12 / sizeof(src1_t);
51405171
size_t s13 = nb13 / sizeof(src1_t);
51415172

5173+
GGML_ASSERT(s0 == 1);
5174+
GGML_ASSERT(s10 == 1);
51425175

51435176
const int block_size = 128;
51445177

@@ -6449,36 +6482,34 @@ static void ggml_cuda_op_get_rows(
64496482

64506483
GGML_ASSERT(src1->type == GGML_TYPE_I32);
64516484
GGML_ASSERT(dst->type == GGML_TYPE_F32);
6452-
GGML_ASSERT(ggml_is_contiguous(src0));
6453-
GGML_ASSERT(ggml_is_contiguous(src1));
6454-
GGML_ASSERT(ggml_is_contiguous(dst));
64556485

6456-
const int ncols = src0->ne[0];
6457-
const int nrows = ggml_nelements(src1);
6486+
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
6487+
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
6488+
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
64586489

64596490
const int32_t * src1_i32 = (const int32_t *) src1_d;
64606491

64616492
switch (src0->type) {
64626493
case GGML_TYPE_F16:
6463-
get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6494+
get_rows_cuda<1, 1, convert_f16>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64646495
break;
64656496
case GGML_TYPE_F32:
6466-
get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6497+
get_rows_cuda<1, 1, convert_f32>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64676498
break;
64686499
case GGML_TYPE_Q4_0:
6469-
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6500+
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64706501
break;
64716502
case GGML_TYPE_Q4_1:
6472-
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6503+
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64736504
break;
64746505
case GGML_TYPE_Q5_0:
6475-
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6506+
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64766507
break;
64776508
case GGML_TYPE_Q5_1:
6478-
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6509+
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64796510
break;
64806511
case GGML_TYPE_Q8_0:
6481-
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
6512+
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
64826513
break;
64836514
default:
64846515
// TODO: k-quants
@@ -8286,11 +8317,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
82868317

82878318
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
82888319

8289-
if (src1->backend == GGML_BACKEND_GPU) {
8290-
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
8291-
} else {
8292-
src1_row.data = (char *) src1->data + i01*src1->nb[1];
8293-
}
8320+
src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
8321+
src1_row.data = (char *) src1->data + i01*src1->nb[1];
82948322

82958323
dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
82968324

@@ -8707,9 +8735,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
87078735
func = ggml_cuda_repeat;
87088736
break;
87098737
case GGML_OP_GET_ROWS:
8710-
if (ggml_is_contiguous(tensor->src[1])) {
8711-
func = ggml_cuda_get_rows;
8712-
}
8738+
func = ggml_cuda_get_rows;
87138739
break;
87148740
case GGML_OP_DUP:
87158741
func = ggml_cuda_dup;
@@ -9215,14 +9241,28 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
92159241
}
92169242
return true;
92179243
} break;
9244+
case GGML_OP_GET_ROWS:
9245+
{
9246+
switch (op->src[0]->type) {
9247+
case GGML_TYPE_F16:
9248+
case GGML_TYPE_F32:
9249+
case GGML_TYPE_Q4_0:
9250+
case GGML_TYPE_Q4_1:
9251+
case GGML_TYPE_Q5_0:
9252+
case GGML_TYPE_Q5_1:
9253+
case GGML_TYPE_Q8_0:
9254+
return true;
9255+
default:
9256+
return false;
9257+
}
9258+
} break;
92189259
case GGML_OP_NONE:
92199260
case GGML_OP_RESHAPE:
92209261
case GGML_OP_VIEW:
92219262
case GGML_OP_PERMUTE:
92229263
case GGML_OP_TRANSPOSE:
92239264
case GGML_OP_NORM:
92249265
case GGML_OP_REPEAT:
9225-
case GGML_OP_GET_ROWS:
92269266
case GGML_OP_DUP:
92279267
case GGML_OP_ADD:
92289268
case GGML_OP_MUL:
@@ -9298,7 +9338,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
92989338
UNUSED(params);
92999339
}
93009340

9301-
extern "C" int ggml_backend_cuda_reg_devices() {
9341+
extern "C" int ggml_backend_cuda_reg_devices();
9342+
9343+
int ggml_backend_cuda_reg_devices() {
93029344
int device_count = ggml_cuda_get_device_count();
93039345
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
93049346
for (int i = 0; i < device_count; i++) {

llama.cpp

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,12 +4254,13 @@ struct llm_build_context {
42544254

42554255
// select experts
42564256
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
4257+
cb(selected_experts->src[0], "ffn_moe_argsort", il);
4258+
42574259
ggml_tensor * weights = ggml_get_rows(ctx0,
4258-
ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
4260+
ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
42594261
cb(weights, "ffn_moe_weights", il);
42604262

4261-
weights = ggml_reshape_2d(ctx0, weights,
4262-
n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
4263+
weights = ggml_reshape_2d(ctx0, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
42634264

42644265
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
42654266
cb(weights_sum, "ffn_moe_weights_sum", il);
@@ -4268,7 +4269,7 @@ struct llm_build_context {
42684269
cb(weights, "ffn_moe_weights_norm", il);
42694270

42704271
// compute expert outputs
4271-
ggml_tensor * moe_out;
4272+
ggml_tensor * moe_out = nullptr;
42724273

42734274
for (int i = 0; i < n_experts_per_tok; ++i) {
42744275
ggml_tensor * cur_expert;
@@ -4279,19 +4280,19 @@ struct llm_build_context {
42794280
ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
42804281

42814282
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
4282-
cb(cur_up, "ffn_up", il);
4283+
cb(cur_up, "ffn_moe_up", il);
42834284

42844285
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
4285-
cb(cur_gate, "ffn_gate", il);
4286+
cb(cur_gate, "ffn_moe_gate", il);
42864287

42874288
cur_gate = ggml_silu(ctx0, cur_gate);
4288-
cb(cur_gate, "ffn_silu", il);
4289+
cb(cur_gate, "ffn_moe_silu", il);
42894290

42904291
cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
4291-
cb(cur_expert, "ffn_gate_par", il);
4292+
cb(cur_expert, "ffn_moe_gate_par", il);
42924293

42934294
cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
4294-
cb(cur_expert, "ffn_down", il);
4295+
cb(cur_expert, "ffn_moe_down", il);
42954296

42964297
cur_expert = ggml_mul(ctx0, cur_expert,
42974298
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
@@ -5562,10 +5563,15 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
55625563

55635564
{ "ffn_moe_logits", OFFLOAD_FUNC },
55645565
{ "ffn_moe_probs", OFFLOAD_FUNC },
5565-
{ "ffn_moe_weights", OFFLOAD_FUNC_NOP },
5566+
{ "ffn_moe_argsort", OFFLOAD_FUNC },
5567+
{ "ffn_moe_weights", OFFLOAD_FUNC },
55665568
{ "ffn_moe_weights_sum", OFFLOAD_FUNC },
55675569
{ "ffn_moe_weights_norm", OFFLOAD_FUNC },
55685570
{ "ffn_moe_weighted", OFFLOAD_FUNC },
5571+
{ "ffn_moe_up", OFFLOAD_FUNC },
5572+
{ "ffn_moe_gate", OFFLOAD_FUNC },
5573+
{ "ffn_moe_gate_par", OFFLOAD_FUNC },
5574+
{ "ffn_moe_down", OFFLOAD_FUNC },
55695575
{ "ffn_moe_out", OFFLOAD_FUNC },
55705576

55715577
{ "l_out", OFFLOAD_FUNC },

0 commit comments

Comments
 (0)