@@ -1686,31 +1686,39 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
-    const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
-    if (col >= ncols) {
+static __global__ void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+    if (i00 >= ne00) {
         return;
     }
 
-    const int r = y[row];
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    // copy x[r*ncols + col] to dst[row*ncols + col]
-    const int xi = r*ncols + col;
-    const int di = row*ncols + col;
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = xi/qk; // block index
-    const int iqs = (xi%qk)/qr; // quant index
-    const int iybs = di - di%qk; // y block start index
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
     // dequantize
     dfloat2 v;
-    dequantize_kernel(x, ib, iqs, v);
+    dequantize_kernel(src0_row, ib, iqs, v);
 
-    dst[iybs + iqs + 0]        = v.x;
-    dst[iybs + iqs + y_offset] = v.y;
+    dst_row[iybs + iqs + 0]        = v.x;
+    dst_row[iybs + iqs + y_offset] = v.y;
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -5055,11 +5063,35 @@ static __global__ void im2col_f32_f16(
 }
 
 template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, nrows, 1);
-    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
+
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
+
+    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+            src0_dd, src1_dd, dst_dd,
+            ne00, /*ne01, ne02, ne03,*/
+            /*ne10, ne11,*/ ne12, /*ne13,*/
+            /*s0,*/ s1, s2, s3,
+            /*nb00,*/ nb01, nb02, nb03,
+            s10, s11, s12/*, s13*/);
+
+    (void) dst;
 }
 
 template <float (*bin_op)(const float, const float)>
@@ -5071,7 +5103,6 @@ struct bin_bcast_cuda {
 
         GGML_TENSOR_BINARY_OP_LOCALS
 
-
         int nr0 = ne10/ne0;
         int nr1 = ne11/ne1;
         int nr2 = ne12/ne2;
@@ -5119,26 +5150,28 @@ struct bin_bcast_cuda {
         int64_t ne12 = cne1[2];
         int64_t ne13 = cne1[3];
 
-        //size_t nb0 = cnb0[0];
+        size_t nb0 = cnb0[0];
         size_t nb1 = cnb0[1];
         size_t nb2 = cnb0[2];
         size_t nb3 = cnb0[3];
 
-        //size_t nb10 = cnb1[0];
+        size_t nb10 = cnb1[0];
         size_t nb11 = cnb1[1];
         size_t nb12 = cnb1[2];
         size_t nb13 = cnb1[3];
 
-        //size_t s0 = nb0 / sizeof(src1_t);
+        size_t s0 = nb0 / sizeof(src1_t);
         size_t s1 = nb1 / sizeof(src1_t);
         size_t s2 = nb2 / sizeof(src1_t);
         size_t s3 = nb3 / sizeof(src1_t);
 
-        //size_t s10 = nb10 / sizeof(src1_t);
+        size_t s10 = nb10 / sizeof(src1_t);
         size_t s11 = nb11 / sizeof(src1_t);
         size_t s12 = nb12 / sizeof(src1_t);
         size_t s13 = nb13 / sizeof(src1_t);
 
+        GGML_ASSERT(s0 == 1);
+        GGML_ASSERT(s10 == 1);
 
         const int block_size = 128;
 
@@ -6449,36 +6482,34 @@ static void ggml_cuda_op_get_rows(
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
-    const int ncols = src0->ne[0];
-    const int nrows = ggml_nelements(src1);
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
     const int32_t * src1_i32 = (const int32_t *) src1_d;
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<1, 1, convert_f16>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<1, 1, convert_f32>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -8286,11 +8317,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
 
         const struct ggml_tensor * src0_row = dst->src[row_id + 2];
 
-        if (src1->backend == GGML_BACKEND_GPU) {
-            src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
-        } else {
-            src1_row.data = (char *) src1->data + i01*src1->nb[1];
-        }
+        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        src1_row.data = (char *) src1->data + i01*src1->nb[1];
 
         dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
 
@@ -8707,9 +8735,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             func = ggml_cuda_repeat;
             break;
         case GGML_OP_GET_ROWS:
-            if (ggml_is_contiguous(tensor->src[1])) {
-                func = ggml_cuda_get_rows;
-            }
+            func = ggml_cuda_get_rows;
             break;
         case GGML_OP_DUP:
             func = ggml_cuda_dup;
@@ -9215,14 +9241,28 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten
                 }
                 return true;
             } break;
+        case GGML_OP_GET_ROWS:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        return true;
+                    default:
+                        return false;
+                }
+            } break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_NORM:
         case GGML_OP_REPEAT:
-        case GGML_OP_GET_ROWS:
         case GGML_OP_DUP:
         case GGML_OP_ADD:
         case GGML_OP_MUL:
@@ -9298,7 +9338,9 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use
     UNUSED(params);
 }
 
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
     int device_count = ggml_cuda_get_device_count();
     //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
     for (int i = 0; i < device_count; i++) {