Skip to content

Commit 96293fd

Browse files
Fully enable usm_ndarray in-place arithmetic operators (#1352)
* Binary elementwise functions can now act on any input in-place - A temporary will be allocated as necessary (i.e., when arrays overlap, are not going to be cast, and are not the same logical arrays) - Uses dedicated in-place kernels where they are implemented - Now called directly by Python operators - Removes _inplace method of BinaryElementwiseFunc class - Removes _find_inplace_dtype function * Tests for new out parameter behavior for add * Broadcasting made conditional in binary functions where memory overlap is possible - Broadcasting can change the values of strides without changing array shape * Changed exception types raised: Use ExecutionPlacementError for CFD violations. Use ValueError if types of input are as expected, but values are not as expected. * Adding tests to improve coverage: Removed tests expecting an error to be raised in case of overlapping inputs. Added tests guided by coverage report. * Removed provably unreachable branches in _resolve_weak_types: Since o1_dtype_kind_num > o2_dtype_kind_num, o1 cannot be a weak boolean type, since it has the lowest kind number in the hierarchy. * All in-place operators now use call operator of BinaryElementwiseFunc * Removed some redundant and obsolete tests - Removed from test_floor_ceil_trunc, test_hyperbolic, test_trigonometric, and test_logaddexp - These tests would fail on GPU but never run on CPU, and therefore were not impacting the coverage - These tests focused on aspects of the BinaryElementwiseFunc class rather than the behavior of the operator --------- Co-authored-by: Oleksandr Pavlyk <[email protected]>
1 parent 852f4b1 commit 96293fd

9 files changed

+218
-486
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 114 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
_acceptance_fn_default,
3232
_find_buf_dtype,
3333
_find_buf_dtype2,
34-
_find_inplace_dtype,
3534
_to_device_supported_dtype,
3635
)
3736

@@ -79,8 +78,8 @@ def __call__(self, x, out=None, order="K"):
7978
)
8079

8180
if out.shape != x.shape:
82-
raise TypeError(
83-
"The shape of input and output arrays are inconsistent."
81+
raise ValueError(
82+
"The shape of input and output arrays are inconsistent. "
8483
f"Expected output shape is {x.shape}, got {out.shape}"
8584
)
8685

@@ -104,7 +103,7 @@ def __call__(self, x, out=None, order="K"):
104103
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
105104
is None
106105
):
107-
raise TypeError(
106+
raise ExecutionPlacementError(
108107
"Input and output allocation queues are not compatible"
109108
)
110109

@@ -302,8 +301,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
302301
o1_kind_num = _weak_type_num_kind(o1_dtype)
303302
o2_kind_num = _strong_dtype_num_kind(o2_dtype)
304303
if o1_kind_num > o2_kind_num:
305-
if isinstance(o1_dtype, WeakBooleanType):
306-
return dpt.bool, o2_dtype
307304
if isinstance(o1_dtype, WeakIntegralType):
308305
return dpt.int64, o2_dtype
309306
if isinstance(o1_dtype, WeakComplexType):
@@ -323,8 +320,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
323320
o1_kind_num = _strong_dtype_num_kind(o1_dtype)
324321
o2_kind_num = _weak_type_num_kind(o2_dtype)
325322
if o2_kind_num > o1_kind_num:
326-
if isinstance(o2_dtype, WeakBooleanType):
327-
return o1_dtype, dpt.bool
328323
if isinstance(o2_dtype, WeakIntegralType):
329324
return o1_dtype, dpt.int64
330325
if isinstance(o2_dtype, WeakComplexType):
@@ -383,14 +378,6 @@ def __repr__(self):
383378
return f"<{self.__name__} '{self.name_}'>"
384379

385380
def __call__(self, o1, o2, out=None, order="K"):
386-
# FIXME: replace with check against base array
387-
# when views can be identified
388-
if self.binary_inplace_fn_:
389-
if o1 is out:
390-
return self._inplace(o1, o2)
391-
elif o2 is out:
392-
return self._inplace(o2, o1)
393-
394381
if order not in ["K", "C", "F", "A"]:
395382
order = "K"
396383
q1, o1_usm_type = _get_queue_usm_type(o1)
@@ -472,31 +459,90 @@ def __call__(self, o1, o2, out=None, order="K"):
472459
"supported types according to the casting rule ''safe''."
473460
)
474461

462+
orig_out = out
475463
if out is not None:
476464
if not isinstance(out, dpt.usm_ndarray):
477465
raise TypeError(
478466
f"output array must be of usm_ndarray type, got {type(out)}"
479467
)
480468

481469
if out.shape != res_shape:
482-
raise TypeError(
483-
"The shape of input and output arrays are inconsistent."
470+
raise ValueError(
471+
"The shape of input and output arrays are inconsistent. "
484472
f"Expected output shape is {o1_shape}, got {out.shape}"
485473
)
486474

487-
if ti._array_overlap(o1, out) or ti._array_overlap(o2, out):
488-
raise TypeError("Input and output arrays have memory overlap")
475+
if res_dt != out.dtype:
476+
raise TypeError(
477+
f"Output array of type {res_dt} is needed,"
478+
f"got {out.dtype}"
479+
)
489480

490481
if (
491-
dpctl.utils.get_execution_queue(
492-
(o1.sycl_queue, o2.sycl_queue, out.sycl_queue)
493-
)
482+
dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
494483
is None
495484
):
496-
raise TypeError(
485+
raise ExecutionPlacementError(
497486
"Input and output allocation queues are not compatible"
498487
)
499488

489+
if isinstance(o1, dpt.usm_ndarray):
490+
if ti._array_overlap(o1, out) and buf1_dt is None:
491+
if not ti._same_logical_tensors(o1, out):
492+
out = dpt.empty_like(out)
493+
elif self.binary_inplace_fn_ is not None:
494+
# if there is a dedicated in-place kernel
495+
# it can be called here, otherwise continues
496+
if isinstance(o2, dpt.usm_ndarray):
497+
src2 = o2
498+
if (
499+
ti._array_overlap(o2, out)
500+
and not ti._same_logical_tensors(o2, out)
501+
and buf2_dt is None
502+
):
503+
buf2_dt = o2_dtype
504+
else:
505+
src2 = dpt.asarray(
506+
o2, dtype=o2_dtype, sycl_queue=exec_q
507+
)
508+
if buf2_dt is None:
509+
if src2.shape != res_shape:
510+
src2 = dpt.broadcast_to(src2, res_shape)
511+
ht_, _ = self.binary_inplace_fn_(
512+
lhs=o1, rhs=src2, sycl_queue=exec_q
513+
)
514+
ht_.wait()
515+
else:
516+
buf2 = dpt.empty_like(src2, dtype=buf2_dt)
517+
(
518+
ht_copy_ev,
519+
copy_ev,
520+
) = ti._copy_usm_ndarray_into_usm_ndarray(
521+
src=src2, dst=buf2, sycl_queue=exec_q
522+
)
523+
524+
buf2 = dpt.broadcast_to(buf2, res_shape)
525+
ht_, _ = self.binary_inplace_fn_(
526+
lhs=o1,
527+
rhs=buf2,
528+
sycl_queue=exec_q,
529+
depends=[copy_ev],
530+
)
531+
ht_copy_ev.wait()
532+
ht_.wait()
533+
534+
return out
535+
536+
if isinstance(o2, dpt.usm_ndarray):
537+
if (
538+
ti._array_overlap(o2, out)
539+
and not ti._same_logical_tensors(o2, out)
540+
and buf2_dt is None
541+
):
542+
# should not reach if out is reallocated
543+
# after being checked against o1
544+
out = dpt.empty_like(out)
545+
500546
if isinstance(o1, dpt.usm_ndarray):
501547
src1 = o1
502548
else:
@@ -532,19 +578,24 @@ def __call__(self, o1, o2, out=None, order="K"):
532578
sycl_queue=exec_q,
533579
order=order,
534580
)
535-
else:
536-
if res_dt != out.dtype:
537-
raise TypeError(
538-
f"Output array of type {res_dt} is needed,"
539-
f"got {out.dtype}"
540-
)
541-
542-
src1 = dpt.broadcast_to(src1, res_shape)
543-
src2 = dpt.broadcast_to(src2, res_shape)
544-
ht_, _ = self.binary_fn_(
581+
if src1.shape != res_shape:
582+
src1 = dpt.broadcast_to(src1, res_shape)
583+
if src2.shape != res_shape:
584+
src2 = dpt.broadcast_to(src2, res_shape)
585+
ht_binary_ev, binary_ev = self.binary_fn_(
545586
src1=src1, src2=src2, dst=out, sycl_queue=exec_q
546587
)
547-
ht_.wait()
588+
if not (orig_out is None or orig_out is out):
589+
# Copy the out data from temporary buffer to original memory
590+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
591+
src=out,
592+
dst=orig_out,
593+
sycl_queue=exec_q,
594+
depends=[binary_ev],
595+
)
596+
ht_copy_out_ev.wait()
597+
out = orig_out
598+
ht_binary_ev.wait()
548599
return out
549600
elif buf1_dt is None:
550601
if order == "K":
@@ -575,18 +626,28 @@ def __call__(self, o1, o2, out=None, order="K"):
575626
f"Output array of type {res_dt} is needed,"
576627
f"got {out.dtype}"
577628
)
578-
579-
src1 = dpt.broadcast_to(src1, res_shape)
629+
if src1.shape != res_shape:
630+
src1 = dpt.broadcast_to(src1, res_shape)
580631
buf2 = dpt.broadcast_to(buf2, res_shape)
581-
ht_, _ = self.binary_fn_(
632+
ht_binary_ev, binary_ev = self.binary_fn_(
582633
src1=src1,
583634
src2=buf2,
584635
dst=out,
585636
sycl_queue=exec_q,
586637
depends=[copy_ev],
587638
)
639+
if not (orig_out is None or orig_out is out):
640+
# Copy the out data from temporary buffer to original memory
641+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
642+
src=out,
643+
dst=orig_out,
644+
sycl_queue=exec_q,
645+
depends=[binary_ev],
646+
)
647+
ht_copy_out_ev.wait()
648+
out = orig_out
588649
ht_copy_ev.wait()
589-
ht_.wait()
650+
ht_binary_ev.wait()
590651
return out
591652
elif buf2_dt is None:
592653
if order == "K":
@@ -611,24 +672,29 @@ def __call__(self, o1, o2, out=None, order="K"):
611672
sycl_queue=exec_q,
612673
order=order,
613674
)
614-
else:
615-
if res_dt != out.dtype:
616-
raise TypeError(
617-
f"Output array of type {res_dt} is needed,"
618-
f"got {out.dtype}"
619-
)
620675

621676
buf1 = dpt.broadcast_to(buf1, res_shape)
622-
src2 = dpt.broadcast_to(src2, res_shape)
623-
ht_, _ = self.binary_fn_(
677+
if src2.shape != res_shape:
678+
src2 = dpt.broadcast_to(src2, res_shape)
679+
ht_binary_ev, binary_ev = self.binary_fn_(
624680
src1=buf1,
625681
src2=src2,
626682
dst=out,
627683
sycl_queue=exec_q,
628684
depends=[copy_ev],
629685
)
686+
if not (orig_out is None or orig_out is out):
687+
# Copy the out data from temporary buffer to original memory
688+
ht_copy_out_ev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
689+
src=out,
690+
dst=orig_out,
691+
sycl_queue=exec_q,
692+
depends=[binary_ev],
693+
)
694+
ht_copy_out_ev.wait()
695+
out = orig_out
630696
ht_copy_ev.wait()
631-
ht_.wait()
697+
ht_binary_ev.wait()
632698
return out
633699

634700
if order in ["K", "A"]:
@@ -665,11 +731,6 @@ def __call__(self, o1, o2, out=None, order="K"):
665731
sycl_queue=exec_q,
666732
order=order,
667733
)
668-
else:
669-
if res_dt != out.dtype:
670-
raise TypeError(
671-
f"Output array of type {res_dt} is needed, got {out.dtype}"
672-
)
673734

674735
buf1 = dpt.broadcast_to(buf1, res_shape)
675736
buf2 = dpt.broadcast_to(buf2, res_shape)
@@ -682,116 +743,3 @@ def __call__(self, o1, o2, out=None, order="K"):
682743
)
683744
dpctl.SyclEvent.wait_for([ht_copy1_ev, ht_copy2_ev, ht_])
684745
return out
685-
686-
def _inplace(self, lhs, val):
687-
if self.binary_inplace_fn_ is None:
688-
raise ValueError(
689-
f"In-place operation not supported for ufunc '{self.name_}'"
690-
)
691-
if not isinstance(lhs, dpt.usm_ndarray):
692-
raise TypeError(
693-
f"Expected dpctl.tensor.usm_ndarray, got {type(lhs)}"
694-
)
695-
q1, lhs_usm_type = _get_queue_usm_type(lhs)
696-
q2, val_usm_type = _get_queue_usm_type(val)
697-
if q2 is None:
698-
exec_q = q1
699-
usm_type = lhs_usm_type
700-
else:
701-
exec_q = dpctl.utils.get_execution_queue((q1, q2))
702-
if exec_q is None:
703-
raise ExecutionPlacementError(
704-
"Execution placement can not be unambiguously inferred "
705-
"from input arguments."
706-
)
707-
usm_type = dpctl.utils.get_coerced_usm_type(
708-
(
709-
lhs_usm_type,
710-
val_usm_type,
711-
)
712-
)
713-
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
714-
lhs_shape = _get_shape(lhs)
715-
val_shape = _get_shape(val)
716-
if not all(
717-
isinstance(s, (tuple, list))
718-
for s in (
719-
lhs_shape,
720-
val_shape,
721-
)
722-
):
723-
raise TypeError(
724-
"Shape of arguments can not be inferred. "
725-
"Arguments are expected to be "
726-
"lists, tuples, or both"
727-
)
728-
try:
729-
res_shape = _broadcast_shape_impl(
730-
[
731-
lhs_shape,
732-
val_shape,
733-
]
734-
)
735-
except ValueError:
736-
raise ValueError(
737-
"operands could not be broadcast together with shapes "
738-
f"{lhs_shape} and {val_shape}"
739-
)
740-
if res_shape != lhs_shape:
741-
raise ValueError(
742-
f"output shape {lhs_shape} does not match "
743-
f"broadcast shape {res_shape}"
744-
)
745-
sycl_dev = exec_q.sycl_device
746-
lhs_dtype = lhs.dtype
747-
val_dtype = _get_dtype(val, sycl_dev)
748-
if not _validate_dtype(val_dtype):
749-
raise ValueError("Input operand of unsupported type")
750-
751-
lhs_dtype, val_dtype = _resolve_weak_types(
752-
lhs_dtype, val_dtype, sycl_dev
753-
)
754-
755-
buf_dt = _find_inplace_dtype(
756-
lhs_dtype, val_dtype, self.result_type_resolver_fn_, sycl_dev
757-
)
758-
759-
if buf_dt is None:
760-
raise TypeError(
761-
f"In-place '{self.name_}' does not support input types "
762-
f"({lhs_dtype}, {val_dtype}), "
763-
"and the inputs could not be safely coerced to any "
764-
"supported types according to the casting rule ''safe''."
765-
)
766-
767-
if isinstance(val, dpt.usm_ndarray):
768-
rhs = val
769-
overlap = ti._array_overlap(lhs, rhs)
770-
else:
771-
rhs = dpt.asarray(val, dtype=val_dtype, sycl_queue=exec_q)
772-
overlap = False
773-
774-
if buf_dt == val_dtype and overlap is False:
775-
rhs = dpt.broadcast_to(rhs, res_shape)
776-
ht_, _ = self.binary_inplace_fn_(
777-
lhs=lhs, rhs=rhs, sycl_queue=exec_q
778-
)
779-
ht_.wait()
780-
781-
else:
782-
buf = dpt.empty_like(rhs, dtype=buf_dt)
783-
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
784-
src=rhs, dst=buf, sycl_queue=exec_q
785-
)
786-
787-
buf = dpt.broadcast_to(buf, res_shape)
788-
ht_, _ = self.binary_inplace_fn_(
789-
lhs=lhs,
790-
rhs=buf,
791-
sycl_queue=exec_q,
792-
depends=[copy_ev],
793-
)
794-
ht_copy_ev.wait()
795-
ht_.wait()
796-
797-
return lhs

0 commit comments

Comments
 (0)