Skip to content

Commit 1c58096

Browse files
committed
sycl: Enhance OP support judgment
1 parent bee1cec commit 1c58096

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5733,8 +5733,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
57335733
case GGML_OP_CONCAT:
57345734
{
57355735
ggml_type src0_type = op->src[0]->type;
5736-
int dim = op->op_params[0];
5737-
return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16 && dim == 2;
5736+
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
57385737
} break;
57395738
case GGML_OP_DUP:
57405739
case GGML_OP_ARGMAX:
@@ -5797,9 +5796,23 @@ static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_
57975796
return buft_ctx->device == sycl_ctx->device;
57985797
}
57995798

5799+
static int64_t get_op_batch_size(const ggml_tensor * op) {
5800+
switch (op->op) {
5801+
case GGML_OP_GET_ROWS:
5802+
return op->ne[1]; // this will increse the speed of prefill in test
5803+
case GGML_OP_MUL_MAT:
5804+
return op->ne[1];
5805+
case GGML_OP_MUL_MAT_ID:
5806+
case GGML_OP_ROPE:
5807+
return op->ne[2];
5808+
default:
5809+
return ggml_nrows(op);
5810+
}
5811+
}
5812+
58005813
static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
58015814
const int min_batch_size = 32;
5802-
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
5815+
return get_op_batch_size(op) >= min_batch_size;
58035816
GGML_UNUSED(dev);
58045817
}
58055818

ggml/src/ggml-sycl/concat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
106106
concat_f32_dim1(x, y, dst, ne0, ne01, item_ct1);
107107
});
108108
break;
109-
default:
109+
case 2:
110110
stream->parallel_for(
111111
sycl::nd_range<3>(gridDim *
112112
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),

0 commit comments

Comments
 (0)