@@ -65,9 +65,6 @@ struct UnaryContigFunctor
65
65
if constexpr (UnaryOperatorT::is_constant::value) {
66
66
// value of operator is known to be a known constant
67
67
constexpr resT const_val = UnaryOperatorT::constant_value;
68
- using out_ptrT =
69
- sycl::multi_ptr<resT,
70
- sycl::access ::address_space::global_space>;
71
68
72
69
auto sg = ndit.get_sub_group ();
73
70
std::uint8_t sgSize = sg.get_local_range ()[0 ];
@@ -80,8 +77,11 @@ struct UnaryContigFunctor
80
77
sycl::vec<resT, vec_sz> res_vec (const_val);
81
78
#pragma unroll
82
79
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
83
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
84
- res_vec);
80
+ auto out_multi_ptr = sycl::address_space_cast<
81
+ sycl::access ::address_space::global_space,
82
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
83
+
84
+ sg.store <vec_sz>(out_multi_ptr, res_vec);
85
85
}
86
86
}
87
87
else {
@@ -94,13 +94,6 @@ struct UnaryContigFunctor
94
94
else if constexpr (UnaryOperatorT::supports_sg_loadstore::value &&
95
95
UnaryOperatorT::supports_vec::value)
96
96
{
97
- using in_ptrT =
98
- sycl::multi_ptr<const argT,
99
- sycl::access ::address_space::global_space>;
100
- using out_ptrT =
101
- sycl::multi_ptr<resT,
102
- sycl::access ::address_space::global_space>;
103
-
104
97
auto sg = ndit.get_sub_group ();
105
98
std::uint16_t sgSize = sg.get_local_range ()[0 ];
106
99
std::uint16_t max_sgSize = sg.get_max_local_range ()[0 ];
@@ -113,10 +106,16 @@ struct UnaryContigFunctor
113
106
114
107
#pragma unroll
115
108
for (std::uint16_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
116
- x = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
109
+ auto in_multi_ptr = sycl::address_space_cast<
110
+ sycl::access ::address_space::global_space,
111
+ sycl::access ::decorated::yes>(&in[base + it * sgSize]);
112
+ auto out_multi_ptr = sycl::address_space_cast<
113
+ sycl::access ::address_space::global_space,
114
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
115
+
116
+ x = sg.load <vec_sz>(in_multi_ptr);
117
117
sycl::vec<resT, vec_sz> res_vec = op (x);
118
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
119
- res_vec);
118
+ sg.store <vec_sz>(out_multi_ptr, res_vec);
120
119
}
121
120
}
122
121
else {
@@ -141,23 +140,23 @@ struct UnaryContigFunctor
141
140
142
141
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
143
142
(maxsgSize == sgSize)) {
144
- using in_ptrT =
145
- sycl::multi_ptr<const argT,
146
- sycl::access ::address_space::global_space>;
147
- using out_ptrT =
148
- sycl::multi_ptr<resT,
149
- sycl::access ::address_space::global_space>;
150
143
sycl::vec<argT, vec_sz> arg_vec;
151
144
152
145
#pragma unroll
153
146
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
154
- arg_vec = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
147
+ auto in_multi_ptr = sycl::address_space_cast<
148
+ sycl::access ::address_space::global_space,
149
+ sycl::access ::decorated::yes>(&in[base + it * sgSize]);
150
+ auto out_multi_ptr = sycl::address_space_cast<
151
+ sycl::access ::address_space::global_space,
152
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
153
+
154
+ arg_vec = sg.load <vec_sz>(in_multi_ptr);
155
155
#pragma unroll
156
156
for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
157
157
arg_vec[k] = op (arg_vec[k]);
158
158
}
159
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
160
- arg_vec);
159
+ sg.store <vec_sz>(out_multi_ptr, arg_vec);
161
160
}
162
161
}
163
162
else {
@@ -179,24 +178,24 @@ struct UnaryContigFunctor
179
178
180
179
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
181
180
(maxsgSize == sgSize)) {
182
- using in_ptrT =
183
- sycl::multi_ptr<const argT,
184
- sycl::access ::address_space::global_space>;
185
- using out_ptrT =
186
- sycl::multi_ptr<resT,
187
- sycl::access ::address_space::global_space>;
188
181
sycl::vec<argT, vec_sz> arg_vec;
189
182
sycl::vec<resT, vec_sz> res_vec;
190
183
191
184
#pragma unroll
192
185
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
193
- arg_vec = sg.load <vec_sz>(in_ptrT (&in[base + it * sgSize]));
186
+ auto in_multi_ptr = sycl::address_space_cast<
187
+ sycl::access ::address_space::global_space,
188
+ sycl::access ::decorated::yes>(&in[base + it * sgSize]);
189
+ auto out_multi_ptr = sycl::address_space_cast<
190
+ sycl::access ::address_space::global_space,
191
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
192
+
193
+ arg_vec = sg.load <vec_sz>(in_multi_ptr);
194
194
#pragma unroll
195
195
for (std::uint8_t k = 0 ; k < vec_sz; ++k) {
196
196
res_vec[k] = op (arg_vec[k]);
197
197
}
198
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
199
- res_vec);
198
+ sg.store <vec_sz>(out_multi_ptr, res_vec);
200
199
}
201
200
}
202
201
else {
@@ -365,28 +364,26 @@ struct BinaryContigFunctor
365
364
366
365
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
367
366
(sgSize == maxsgSize)) {
368
- using in_ptrT1 =
369
- sycl::multi_ptr<const argT1,
370
- sycl::access ::address_space::global_space>;
371
- using in_ptrT2 =
372
- sycl::multi_ptr<const argT2,
373
- sycl::access ::address_space::global_space>;
374
- using out_ptrT =
375
- sycl::multi_ptr<resT,
376
- sycl::access ::address_space::global_space>;
377
367
sycl::vec<argT1, vec_sz> arg1_vec;
378
368
sycl::vec<argT2, vec_sz> arg2_vec;
379
369
sycl::vec<resT, vec_sz> res_vec;
380
370
381
371
#pragma unroll
382
372
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
383
- arg1_vec =
384
- sg.load <vec_sz>(in_ptrT1 (&in1[base + it * sgSize]));
385
- arg2_vec =
386
- sg.load <vec_sz>(in_ptrT2 (&in2[base + it * sgSize]));
373
+ auto in1_multi_ptr = sycl::address_space_cast<
374
+ sycl::access ::address_space::global_space,
375
+ sycl::access ::decorated::yes>(&in1[base + it * sgSize]);
376
+ auto in2_multi_ptr = sycl::address_space_cast<
377
+ sycl::access ::address_space::global_space,
378
+ sycl::access ::decorated::yes>(&in2[base + it * sgSize]);
379
+ auto out_multi_ptr = sycl::address_space_cast<
380
+ sycl::access ::address_space::global_space,
381
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
382
+
383
+ arg1_vec = sg.load <vec_sz>(in1_multi_ptr);
384
+ arg2_vec = sg.load <vec_sz>(in2_multi_ptr);
387
385
res_vec = op (arg1_vec, arg2_vec);
388
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
389
- res_vec);
386
+ sg.store <vec_sz>(out_multi_ptr, res_vec);
390
387
}
391
388
}
392
389
else {
@@ -407,32 +404,30 @@ struct BinaryContigFunctor
407
404
408
405
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
409
406
(sgSize == maxsgSize)) {
410
- using in_ptrT1 =
411
- sycl::multi_ptr<const argT1,
412
- sycl::access ::address_space::global_space>;
413
- using in_ptrT2 =
414
- sycl::multi_ptr<const argT2,
415
- sycl::access ::address_space::global_space>;
416
- using out_ptrT =
417
- sycl::multi_ptr<resT,
418
- sycl::access ::address_space::global_space>;
419
407
sycl::vec<argT1, vec_sz> arg1_vec;
420
408
sycl::vec<argT2, vec_sz> arg2_vec;
421
409
sycl::vec<resT, vec_sz> res_vec;
422
410
423
411
#pragma unroll
424
412
for (std::uint8_t it = 0 ; it < n_vecs * vec_sz; it += vec_sz) {
425
- arg1_vec =
426
- sg.load <vec_sz>(in_ptrT1 (&in1[base + it * sgSize]));
427
- arg2_vec =
428
- sg.load <vec_sz>(in_ptrT2 (&in2[base + it * sgSize]));
413
+ auto in1_multi_ptr = sycl::address_space_cast<
414
+ sycl::access ::address_space::global_space,
415
+ sycl::access ::decorated::yes>(&in1[base + it * sgSize]);
416
+ auto in2_multi_ptr = sycl::address_space_cast<
417
+ sycl::access ::address_space::global_space,
418
+ sycl::access ::decorated::yes>(&in2[base + it * sgSize]);
419
+ auto out_multi_ptr = sycl::address_space_cast<
420
+ sycl::access ::address_space::global_space,
421
+ sycl::access ::decorated::yes>(&out[base + it * sgSize]);
422
+
423
+ arg1_vec = sg.load <vec_sz>(in1_multi_ptr);
424
+ arg2_vec = sg.load <vec_sz>(in2_multi_ptr);
429
425
#pragma unroll
430
426
for (std::uint8_t vec_id = 0 ; vec_id < vec_sz; ++vec_id) {
431
427
res_vec[vec_id] =
432
428
op (arg1_vec[vec_id], arg2_vec[vec_id]);
433
429
}
434
- sg.store <vec_sz>(out_ptrT (&out[base + it * sgSize]),
435
- res_vec);
430
+ sg.store <vec_sz>(out_multi_ptr, res_vec);
436
431
}
437
432
}
438
433
else {
@@ -530,22 +525,24 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
530
525
size_t base = gid - sg.get_local_id ()[0 ];
531
526
532
527
if (base + sgSize < n_elems) {
533
- using in_ptrT1 =
534
- sycl::multi_ptr<const argT1,
535
- sycl::access ::address_space::global_space>;
536
- using in_ptrT2 =
537
- sycl::multi_ptr<const argT2,
538
- sycl::access ::address_space::global_space>;
539
- using res_ptrT =
540
- sycl::multi_ptr<resT,
541
- sycl::access ::address_space::global_space>;
542
-
543
- const argT1 mat_el = sg.load (in_ptrT1 (&mat[base]));
544
- const argT2 vec_el = sg.load (in_ptrT2 (&padded_vec[base % n1]));
528
+ auto in1_multi_ptr = sycl::address_space_cast<
529
+ sycl::access ::address_space::global_space,
530
+ sycl::access ::decorated::yes>(&mat[base]);
531
+
532
+ auto in2_multi_ptr = sycl::address_space_cast<
533
+ sycl::access ::address_space::global_space,
534
+ sycl::access ::decorated::yes>(&padded_vec[base % n1]);
535
+
536
+ auto out_multi_ptr = sycl::address_space_cast<
537
+ sycl::access ::address_space::global_space,
538
+ sycl::access ::decorated::yes>(&res[base]);
539
+
540
+ const argT1 mat_el = sg.load (in1_multi_ptr);
541
+ const argT2 vec_el = sg.load (in2_multi_ptr);
545
542
546
543
resT res_el = op (mat_el, vec_el);
547
544
548
- sg.store (res_ptrT (&res[base]) , res_el);
545
+ sg.store (out_multi_ptr , res_el);
549
546
}
550
547
else {
551
548
for (size_t k = base + sg.get_local_id ()[0 ]; k < n_elems;
@@ -592,22 +589,24 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
592
589
size_t base = gid - sg.get_local_id ()[0 ];
593
590
594
591
if (base + sgSize < n_elems) {
595
- using in_ptrT1 =
596
- sycl::multi_ptr<const argT1,
597
- sycl::access ::address_space::global_space>;
598
- using in_ptrT2 =
599
- sycl::multi_ptr<const argT2,
600
- sycl::access ::address_space::global_space>;
601
- using res_ptrT =
602
- sycl::multi_ptr<resT,
603
- sycl::access ::address_space::global_space>;
604
-
605
- const argT2 mat_el = sg.load (in_ptrT2 (&mat[base]));
606
- const argT1 vec_el = sg.load (in_ptrT1 (&padded_vec[base % n1]));
592
+ auto in1_multi_ptr = sycl::address_space_cast<
593
+ sycl::access ::address_space::global_space,
594
+ sycl::access ::decorated::yes>(&padded_vec[base % n1]);
595
+
596
+ auto in2_multi_ptr = sycl::address_space_cast<
597
+ sycl::access ::address_space::global_space,
598
+ sycl::access ::decorated::yes>(&mat[base]);
599
+
600
+ auto out_multi_ptr = sycl::address_space_cast<
601
+ sycl::access ::address_space::global_space,
602
+ sycl::access ::decorated::yes>(&res[base]);
603
+
604
+ const argT2 mat_el = sg.load (in2_multi_ptr);
605
+ const argT1 vec_el = sg.load (in1_multi_ptr);
607
606
608
607
resT res_el = op (vec_el, mat_el);
609
608
610
- sg.store (res_ptrT (&res[base]) , res_el);
609
+ sg.store (out_multi_ptr , res_el);
611
610
}
612
611
else {
613
612
for (size_t k = base + sg.get_local_id ()[0 ]; k < n_elems;
0 commit comments