Skip to content

Commit 8eab04b

Browse files
Merge pull request #1366 from IntelPython/create-multi_ptr-per-sycl-2020-standard
Conversion from raw to multi_ptr should be done with address_space_cast
2 parents d85e130 + f7eee1e commit 8eab04b

File tree

4 files changed

+148
-145
lines changed

4 files changed

+148
-145
lines changed

dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,25 +244,26 @@ class ContigCopyFunctor
244244

245245
if (base + n_vecs * vec_sz * sgSize < nelems &&
246246
sgSize == max_sgSize) {
247-
using src_ptrT =
248-
sycl::multi_ptr<const srcT,
249-
sycl::access::address_space::global_space>;
250-
using dst_ptrT =
251-
sycl::multi_ptr<dstT,
252-
sycl::access::address_space::global_space>;
253247
sycl::vec<srcT, vec_sz> src_vec;
254248
sycl::vec<dstT, vec_sz> dst_vec;
255249

256250
#pragma unroll
257251
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
258-
src_vec =
259-
sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
252+
auto src_multi_ptr = sycl::address_space_cast<
253+
sycl::access::address_space::global_space,
254+
sycl::access::decorated::yes>(
255+
&src_p[base + it * sgSize]);
256+
auto dst_multi_ptr = sycl::address_space_cast<
257+
sycl::access::address_space::global_space,
258+
sycl::access::decorated::yes>(
259+
&dst_p[base + it * sgSize]);
260+
261+
src_vec = sg.load<vec_sz>(src_multi_ptr);
260262
#pragma unroll
261263
for (std::uint8_t k = 0; k < vec_sz; k++) {
262264
dst_vec[k] = fn(src_vec[k]);
263265
}
264-
sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
265-
dst_vec);
266+
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
266267
}
267268
}
268269
else {

dpctl/tensor/libtensor/include/kernels/elementwise_functions/common.hpp

Lines changed: 88 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ struct UnaryContigFunctor
6565
if constexpr (UnaryOperatorT::is_constant::value) {
6666
// value of operator is known to be a known constant
6767
constexpr resT const_val = UnaryOperatorT::constant_value;
68-
using out_ptrT =
69-
sycl::multi_ptr<resT,
70-
sycl::access::address_space::global_space>;
7168

7269
auto sg = ndit.get_sub_group();
7370
std::uint8_t sgSize = sg.get_local_range()[0];
@@ -80,8 +77,11 @@ struct UnaryContigFunctor
8077
sycl::vec<resT, vec_sz> res_vec(const_val);
8178
#pragma unroll
8279
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
83-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
84-
res_vec);
80+
auto out_multi_ptr = sycl::address_space_cast<
81+
sycl::access::address_space::global_space,
82+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
83+
84+
sg.store<vec_sz>(out_multi_ptr, res_vec);
8585
}
8686
}
8787
else {
@@ -94,13 +94,6 @@ struct UnaryContigFunctor
9494
else if constexpr (UnaryOperatorT::supports_sg_loadstore::value &&
9595
UnaryOperatorT::supports_vec::value)
9696
{
97-
using in_ptrT =
98-
sycl::multi_ptr<const argT,
99-
sycl::access::address_space::global_space>;
100-
using out_ptrT =
101-
sycl::multi_ptr<resT,
102-
sycl::access::address_space::global_space>;
103-
10497
auto sg = ndit.get_sub_group();
10598
std::uint16_t sgSize = sg.get_local_range()[0];
10699
std::uint16_t max_sgSize = sg.get_max_local_range()[0];
@@ -113,10 +106,16 @@ struct UnaryContigFunctor
113106

114107
#pragma unroll
115108
for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
116-
x = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
109+
auto in_multi_ptr = sycl::address_space_cast<
110+
sycl::access::address_space::global_space,
111+
sycl::access::decorated::yes>(&in[base + it * sgSize]);
112+
auto out_multi_ptr = sycl::address_space_cast<
113+
sycl::access::address_space::global_space,
114+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
115+
116+
x = sg.load<vec_sz>(in_multi_ptr);
117117
sycl::vec<resT, vec_sz> res_vec = op(x);
118-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
119-
res_vec);
118+
sg.store<vec_sz>(out_multi_ptr, res_vec);
120119
}
121120
}
122121
else {
@@ -141,23 +140,23 @@ struct UnaryContigFunctor
141140

142141
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
143142
(maxsgSize == sgSize)) {
144-
using in_ptrT =
145-
sycl::multi_ptr<const argT,
146-
sycl::access::address_space::global_space>;
147-
using out_ptrT =
148-
sycl::multi_ptr<resT,
149-
sycl::access::address_space::global_space>;
150143
sycl::vec<argT, vec_sz> arg_vec;
151144

152145
#pragma unroll
153146
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
154-
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
147+
auto in_multi_ptr = sycl::address_space_cast<
148+
sycl::access::address_space::global_space,
149+
sycl::access::decorated::yes>(&in[base + it * sgSize]);
150+
auto out_multi_ptr = sycl::address_space_cast<
151+
sycl::access::address_space::global_space,
152+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
153+
154+
arg_vec = sg.load<vec_sz>(in_multi_ptr);
155155
#pragma unroll
156156
for (std::uint8_t k = 0; k < vec_sz; ++k) {
157157
arg_vec[k] = op(arg_vec[k]);
158158
}
159-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
160-
arg_vec);
159+
sg.store<vec_sz>(out_multi_ptr, arg_vec);
161160
}
162161
}
163162
else {
@@ -179,24 +178,24 @@ struct UnaryContigFunctor
179178

180179
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
181180
(maxsgSize == sgSize)) {
182-
using in_ptrT =
183-
sycl::multi_ptr<const argT,
184-
sycl::access::address_space::global_space>;
185-
using out_ptrT =
186-
sycl::multi_ptr<resT,
187-
sycl::access::address_space::global_space>;
188181
sycl::vec<argT, vec_sz> arg_vec;
189182
sycl::vec<resT, vec_sz> res_vec;
190183

191184
#pragma unroll
192185
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
193-
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
186+
auto in_multi_ptr = sycl::address_space_cast<
187+
sycl::access::address_space::global_space,
188+
sycl::access::decorated::yes>(&in[base + it * sgSize]);
189+
auto out_multi_ptr = sycl::address_space_cast<
190+
sycl::access::address_space::global_space,
191+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
192+
193+
arg_vec = sg.load<vec_sz>(in_multi_ptr);
194194
#pragma unroll
195195
for (std::uint8_t k = 0; k < vec_sz; ++k) {
196196
res_vec[k] = op(arg_vec[k]);
197197
}
198-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
199-
res_vec);
198+
sg.store<vec_sz>(out_multi_ptr, res_vec);
200199
}
201200
}
202201
else {
@@ -365,28 +364,26 @@ struct BinaryContigFunctor
365364

366365
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
367366
(sgSize == maxsgSize)) {
368-
using in_ptrT1 =
369-
sycl::multi_ptr<const argT1,
370-
sycl::access::address_space::global_space>;
371-
using in_ptrT2 =
372-
sycl::multi_ptr<const argT2,
373-
sycl::access::address_space::global_space>;
374-
using out_ptrT =
375-
sycl::multi_ptr<resT,
376-
sycl::access::address_space::global_space>;
377367
sycl::vec<argT1, vec_sz> arg1_vec;
378368
sycl::vec<argT2, vec_sz> arg2_vec;
379369
sycl::vec<resT, vec_sz> res_vec;
380370

381371
#pragma unroll
382372
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
383-
arg1_vec =
384-
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
385-
arg2_vec =
386-
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
373+
auto in1_multi_ptr = sycl::address_space_cast<
374+
sycl::access::address_space::global_space,
375+
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
376+
auto in2_multi_ptr = sycl::address_space_cast<
377+
sycl::access::address_space::global_space,
378+
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
379+
auto out_multi_ptr = sycl::address_space_cast<
380+
sycl::access::address_space::global_space,
381+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
382+
383+
arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
384+
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
387385
res_vec = op(arg1_vec, arg2_vec);
388-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
389-
res_vec);
386+
sg.store<vec_sz>(out_multi_ptr, res_vec);
390387
}
391388
}
392389
else {
@@ -407,32 +404,30 @@ struct BinaryContigFunctor
407404

408405
if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
409406
(sgSize == maxsgSize)) {
410-
using in_ptrT1 =
411-
sycl::multi_ptr<const argT1,
412-
sycl::access::address_space::global_space>;
413-
using in_ptrT2 =
414-
sycl::multi_ptr<const argT2,
415-
sycl::access::address_space::global_space>;
416-
using out_ptrT =
417-
sycl::multi_ptr<resT,
418-
sycl::access::address_space::global_space>;
419407
sycl::vec<argT1, vec_sz> arg1_vec;
420408
sycl::vec<argT2, vec_sz> arg2_vec;
421409
sycl::vec<resT, vec_sz> res_vec;
422410

423411
#pragma unroll
424412
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
425-
arg1_vec =
426-
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
427-
arg2_vec =
428-
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
413+
auto in1_multi_ptr = sycl::address_space_cast<
414+
sycl::access::address_space::global_space,
415+
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
416+
auto in2_multi_ptr = sycl::address_space_cast<
417+
sycl::access::address_space::global_space,
418+
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
419+
auto out_multi_ptr = sycl::address_space_cast<
420+
sycl::access::address_space::global_space,
421+
sycl::access::decorated::yes>(&out[base + it * sgSize]);
422+
423+
arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
424+
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
429425
#pragma unroll
430426
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
431427
res_vec[vec_id] =
432428
op(arg1_vec[vec_id], arg2_vec[vec_id]);
433429
}
434-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
435-
res_vec);
430+
sg.store<vec_sz>(out_multi_ptr, res_vec);
436431
}
437432
}
438433
else {
@@ -530,22 +525,24 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
530525
size_t base = gid - sg.get_local_id()[0];
531526

532527
if (base + sgSize < n_elems) {
533-
using in_ptrT1 =
534-
sycl::multi_ptr<const argT1,
535-
sycl::access::address_space::global_space>;
536-
using in_ptrT2 =
537-
sycl::multi_ptr<const argT2,
538-
sycl::access::address_space::global_space>;
539-
using res_ptrT =
540-
sycl::multi_ptr<resT,
541-
sycl::access::address_space::global_space>;
542-
543-
const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
544-
const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
528+
auto in1_multi_ptr = sycl::address_space_cast<
529+
sycl::access::address_space::global_space,
530+
sycl::access::decorated::yes>(&mat[base]);
531+
532+
auto in2_multi_ptr = sycl::address_space_cast<
533+
sycl::access::address_space::global_space,
534+
sycl::access::decorated::yes>(&padded_vec[base % n1]);
535+
536+
auto out_multi_ptr = sycl::address_space_cast<
537+
sycl::access::address_space::global_space,
538+
sycl::access::decorated::yes>(&res[base]);
539+
540+
const argT1 mat_el = sg.load(in1_multi_ptr);
541+
const argT2 vec_el = sg.load(in2_multi_ptr);
545542

546543
resT res_el = op(mat_el, vec_el);
547544

548-
sg.store(res_ptrT(&res[base]), res_el);
545+
sg.store(out_multi_ptr, res_el);
549546
}
550547
else {
551548
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
@@ -592,22 +589,24 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
592589
size_t base = gid - sg.get_local_id()[0];
593590

594591
if (base + sgSize < n_elems) {
595-
using in_ptrT1 =
596-
sycl::multi_ptr<const argT1,
597-
sycl::access::address_space::global_space>;
598-
using in_ptrT2 =
599-
sycl::multi_ptr<const argT2,
600-
sycl::access::address_space::global_space>;
601-
using res_ptrT =
602-
sycl::multi_ptr<resT,
603-
sycl::access::address_space::global_space>;
604-
605-
const argT2 mat_el = sg.load(in_ptrT2(&mat[base]));
606-
const argT1 vec_el = sg.load(in_ptrT1(&padded_vec[base % n1]));
592+
auto in1_multi_ptr = sycl::address_space_cast<
593+
sycl::access::address_space::global_space,
594+
sycl::access::decorated::yes>(&padded_vec[base % n1]);
595+
596+
auto in2_multi_ptr = sycl::address_space_cast<
597+
sycl::access::address_space::global_space,
598+
sycl::access::decorated::yes>(&mat[base]);
599+
600+
auto out_multi_ptr = sycl::address_space_cast<
601+
sycl::access::address_space::global_space,
602+
sycl::access::decorated::yes>(&res[base]);
603+
604+
const argT2 mat_el = sg.load(in2_multi_ptr);
605+
const argT1 vec_el = sg.load(in1_multi_ptr);
607606

608607
resT res_el = op(vec_el, mat_el);
609608

610-
sg.store(res_ptrT(&res[base]), res_el);
609+
sg.store(out_multi_ptr, res_el);
611610
}
612611
else {
613612
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;

0 commit comments

Comments
 (0)