Skip to content

Commit 12c415b

Browse files
Alcpzarthw
authored and committed
sycl : Fixes to broken builds and test-backend-ops (ggml-org#10257)
* Fixes broken build for the SYCL CUDA backend caused by non-explicit gemm call in outprod (merged in with RWKV6 in Optimize RWKV6 Operator Naming and Implement Multi-core CPU/ SYCL Acceleration ggml-org#10133) * Marks permuted MUL_MAT as unsupported to be able to run test-backend-ops * Fixes asserts in norm to fix debug builds.
1 parent c5e90fe commit 12c415b

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

ggml/src/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4263,6 +4263,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
42634263
if (op->op == GGML_OP_MUL_MAT) {
42644264
a = op->src[0];
42654265
b = op->src[1];
4266+
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
4267+
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
4268+
return false;
4269+
}
42664270
} else {
42674271
a = op->src[2];
42684272
b = op->src[1];

ggml/src/ggml-sycl/norm.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
88

99
const int nthreads = item_ct1.get_local_range(2);
1010
const int nwarps = nthreads / WARP_SIZE;
11-
assert(nwarps % WARP_SIZE == 0);
1211
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
1312

1413
for (int col = tid; col < ncols; col += block_size) {
@@ -55,7 +54,6 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
5554
int end = start + group_size;
5655
const int nthreads = item_ct1.get_local_range(2);
5756
const int nwarps = nthreads / WARP_SIZE;
58-
assert(nwarps % WARP_SIZE == 0);
5957
start += item_ct1.get_local_id(2);
6058
int nreduce = nwarps / WARP_SIZE;
6159

@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
144142
const int tid = item_ct1.get_local_id(2);
145143
const int nthreads = item_ct1.get_local_range(2);
146144
const int nwarps = nthreads / WARP_SIZE;
147-
assert(nwarps % WARP_SIZE == 0);
148145
float tmp = 0.0f; // partial sum for thread in warp
149146

150147
for (int col = tid; col < ncols; col += block_size) {
@@ -204,6 +201,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
204201
}
205202
else {
206203
const int work_group_size = ggml_sycl_info().work_group_size(device_id);
204+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
207205
const sycl::range<3> block_dims(1, 1, work_group_size);
208206
/*
209207
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
@@ -248,6 +246,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
248246
}
249247
else {
250248
const int work_group_size = ggml_sycl_info().work_group_size(device_id);
249+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
251250
const sycl::range<3> block_dims(1, 1, work_group_size);
252251
/*
253252
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
@@ -296,6 +295,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
296295
}
297296
else {
298297
const int work_group_size = ggml_sycl_info().work_group_size(device_id);
298+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
299299
const sycl::range<3> block_dims(1, 1, work_group_size);
300300
/*
301301
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed

ggml/src/ggml-sycl/outprod.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <sycl/sycl.hpp>
2+
#include <oneapi/mkl.hpp>
23
#include "outprod.hpp"
34

45

@@ -39,7 +40,7 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* sr
3940

4041
try {
4142
// Perform matrix multiplication using oneMKL GEMM
42-
oneapi::mkl::blas::gemm(*stream,
43+
oneapi::mkl::blas::column_major::gemm(*stream,
4344
oneapi::mkl::transpose::nontrans, src1_op,
4445
ne0, ne1, ne01,
4546
alpha,

0 commit comments

Comments (0)