31
31
_acceptance_fn_default ,
32
32
_find_buf_dtype ,
33
33
_find_buf_dtype2 ,
34
- _find_inplace_dtype ,
35
34
_to_device_supported_dtype ,
36
35
)
37
36
@@ -79,8 +78,8 @@ def __call__(self, x, out=None, order="K"):
79
78
)
80
79
81
80
if out .shape != x .shape :
82
- raise TypeError (
83
- "The shape of input and output arrays are inconsistent."
81
+ raise ValueError (
82
+ "The shape of input and output arrays are inconsistent. "
84
83
f"Expected output shape is { x .shape } , got { out .shape } "
85
84
)
86
85
@@ -104,7 +103,7 @@ def __call__(self, x, out=None, order="K"):
104
103
dpctl .utils .get_execution_queue ((x .sycl_queue , out .sycl_queue ))
105
104
is None
106
105
):
107
- raise TypeError (
106
+ raise ExecutionPlacementError (
108
107
"Input and output allocation queues are not compatible"
109
108
)
110
109
@@ -302,8 +301,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
302
301
o1_kind_num = _weak_type_num_kind (o1_dtype )
303
302
o2_kind_num = _strong_dtype_num_kind (o2_dtype )
304
303
if o1_kind_num > o2_kind_num :
305
- if isinstance (o1_dtype , WeakBooleanType ):
306
- return dpt .bool , o2_dtype
307
304
if isinstance (o1_dtype , WeakIntegralType ):
308
305
return dpt .int64 , o2_dtype
309
306
if isinstance (o1_dtype , WeakComplexType ):
@@ -323,8 +320,6 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
323
320
o1_kind_num = _strong_dtype_num_kind (o1_dtype )
324
321
o2_kind_num = _weak_type_num_kind (o2_dtype )
325
322
if o2_kind_num > o1_kind_num :
326
- if isinstance (o2_dtype , WeakBooleanType ):
327
- return o1_dtype , dpt .bool
328
323
if isinstance (o2_dtype , WeakIntegralType ):
329
324
return o1_dtype , dpt .int64
330
325
if isinstance (o2_dtype , WeakComplexType ):
@@ -383,14 +378,6 @@ def __repr__(self):
383
378
return f"<{ self .__name__ } '{ self .name_ } '>"
384
379
385
380
def __call__ (self , o1 , o2 , out = None , order = "K" ):
386
- # FIXME: replace with check against base array
387
- # when views can be identified
388
- if self .binary_inplace_fn_ :
389
- if o1 is out :
390
- return self ._inplace (o1 , o2 )
391
- elif o2 is out :
392
- return self ._inplace (o2 , o1 )
393
-
394
381
if order not in ["K" , "C" , "F" , "A" ]:
395
382
order = "K"
396
383
q1 , o1_usm_type = _get_queue_usm_type (o1 )
@@ -472,31 +459,90 @@ def __call__(self, o1, o2, out=None, order="K"):
472
459
"supported types according to the casting rule ''safe''."
473
460
)
474
461
462
+ orig_out = out
475
463
if out is not None :
476
464
if not isinstance (out , dpt .usm_ndarray ):
477
465
raise TypeError (
478
466
f"output array must be of usm_ndarray type, got { type (out )} "
479
467
)
480
468
481
469
if out .shape != res_shape :
482
- raise TypeError (
483
- "The shape of input and output arrays are inconsistent."
470
+ raise ValueError (
471
+ "The shape of input and output arrays are inconsistent. "
484
472
f"Expected output shape is { o1_shape } , got { out .shape } "
485
473
)
486
474
487
- if ti ._array_overlap (o1 , out ) or ti ._array_overlap (o2 , out ):
488
- raise TypeError ("Input and output arrays have memory overlap" )
475
+ if res_dt != out .dtype :
476
+ raise TypeError (
477
+ f"Output array of type { res_dt } is needed,"
478
+ f"got { out .dtype } "
479
+ )
489
480
490
481
if (
491
- dpctl .utils .get_execution_queue (
492
- (o1 .sycl_queue , o2 .sycl_queue , out .sycl_queue )
493
- )
482
+ dpctl .utils .get_execution_queue ((exec_q , out .sycl_queue ))
494
483
is None
495
484
):
496
- raise TypeError (
485
+ raise ExecutionPlacementError (
497
486
"Input and output allocation queues are not compatible"
498
487
)
499
488
489
+ if isinstance (o1 , dpt .usm_ndarray ):
490
+ if ti ._array_overlap (o1 , out ) and buf1_dt is None :
491
+ if not ti ._same_logical_tensors (o1 , out ):
492
+ out = dpt .empty_like (out )
493
+ elif self .binary_inplace_fn_ is not None :
494
+ # if there is a dedicated in-place kernel
495
+ # it can be called here, otherwise continues
496
+ if isinstance (o2 , dpt .usm_ndarray ):
497
+ src2 = o2
498
+ if (
499
+ ti ._array_overlap (o2 , out )
500
+ and not ti ._same_logical_tensors (o2 , out )
501
+ and buf2_dt is None
502
+ ):
503
+ buf2_dt = o2_dtype
504
+ else :
505
+ src2 = dpt .asarray (
506
+ o2 , dtype = o2_dtype , sycl_queue = exec_q
507
+ )
508
+ if buf2_dt is None :
509
+ if src2 .shape != res_shape :
510
+ src2 = dpt .broadcast_to (src2 , res_shape )
511
+ ht_ , _ = self .binary_inplace_fn_ (
512
+ lhs = o1 , rhs = src2 , sycl_queue = exec_q
513
+ )
514
+ ht_ .wait ()
515
+ else :
516
+ buf2 = dpt .empty_like (src2 , dtype = buf2_dt )
517
+ (
518
+ ht_copy_ev ,
519
+ copy_ev ,
520
+ ) = ti ._copy_usm_ndarray_into_usm_ndarray (
521
+ src = src2 , dst = buf2 , sycl_queue = exec_q
522
+ )
523
+
524
+ buf2 = dpt .broadcast_to (buf2 , res_shape )
525
+ ht_ , _ = self .binary_inplace_fn_ (
526
+ lhs = o1 ,
527
+ rhs = buf2 ,
528
+ sycl_queue = exec_q ,
529
+ depends = [copy_ev ],
530
+ )
531
+ ht_copy_ev .wait ()
532
+ ht_ .wait ()
533
+
534
+ return out
535
+
536
+ if isinstance (o2 , dpt .usm_ndarray ):
537
+ if (
538
+ ti ._array_overlap (o2 , out )
539
+ and not ti ._same_logical_tensors (o2 , out )
540
+ and buf2_dt is None
541
+ ):
542
+ # should not reach if out is reallocated
543
+ # after being checked against o1
544
+ out = dpt .empty_like (out )
545
+
500
546
if isinstance (o1 , dpt .usm_ndarray ):
501
547
src1 = o1
502
548
else :
@@ -532,19 +578,24 @@ def __call__(self, o1, o2, out=None, order="K"):
532
578
sycl_queue = exec_q ,
533
579
order = order ,
534
580
)
535
- else :
536
- if res_dt != out .dtype :
537
- raise TypeError (
538
- f"Output array of type { res_dt } is needed,"
539
- f"got { out .dtype } "
540
- )
541
-
542
- src1 = dpt .broadcast_to (src1 , res_shape )
543
- src2 = dpt .broadcast_to (src2 , res_shape )
544
- ht_ , _ = self .binary_fn_ (
581
+ if src1 .shape != res_shape :
582
+ src1 = dpt .broadcast_to (src1 , res_shape )
583
+ if src2 .shape != res_shape :
584
+ src2 = dpt .broadcast_to (src2 , res_shape )
585
+ ht_binary_ev , binary_ev = self .binary_fn_ (
545
586
src1 = src1 , src2 = src2 , dst = out , sycl_queue = exec_q
546
587
)
547
- ht_ .wait ()
588
+ if not (orig_out is None or orig_out is out ):
589
+ # Copy the out data from temporary buffer to original memory
590
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
591
+ src = out ,
592
+ dst = orig_out ,
593
+ sycl_queue = exec_q ,
594
+ depends = [binary_ev ],
595
+ )
596
+ ht_copy_out_ev .wait ()
597
+ out = orig_out
598
+ ht_binary_ev .wait ()
548
599
return out
549
600
elif buf1_dt is None :
550
601
if order == "K" :
@@ -575,18 +626,28 @@ def __call__(self, o1, o2, out=None, order="K"):
575
626
f"Output array of type { res_dt } is needed,"
576
627
f"got { out .dtype } "
577
628
)
578
-
579
- src1 = dpt .broadcast_to (src1 , res_shape )
629
+ if src1 . shape != res_shape :
630
+ src1 = dpt .broadcast_to (src1 , res_shape )
580
631
buf2 = dpt .broadcast_to (buf2 , res_shape )
581
- ht_ , _ = self .binary_fn_ (
632
+ ht_binary_ev , binary_ev = self .binary_fn_ (
582
633
src1 = src1 ,
583
634
src2 = buf2 ,
584
635
dst = out ,
585
636
sycl_queue = exec_q ,
586
637
depends = [copy_ev ],
587
638
)
639
+ if not (orig_out is None or orig_out is out ):
640
+ # Copy the out data from temporary buffer to original memory
641
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
642
+ src = out ,
643
+ dst = orig_out ,
644
+ sycl_queue = exec_q ,
645
+ depends = [binary_ev ],
646
+ )
647
+ ht_copy_out_ev .wait ()
648
+ out = orig_out
588
649
ht_copy_ev .wait ()
589
- ht_ .wait ()
650
+ ht_binary_ev .wait ()
590
651
return out
591
652
elif buf2_dt is None :
592
653
if order == "K" :
@@ -611,24 +672,29 @@ def __call__(self, o1, o2, out=None, order="K"):
611
672
sycl_queue = exec_q ,
612
673
order = order ,
613
674
)
614
- else :
615
- if res_dt != out .dtype :
616
- raise TypeError (
617
- f"Output array of type { res_dt } is needed,"
618
- f"got { out .dtype } "
619
- )
620
675
621
676
buf1 = dpt .broadcast_to (buf1 , res_shape )
622
- src2 = dpt .broadcast_to (src2 , res_shape )
623
- ht_ , _ = self .binary_fn_ (
677
+ if src2 .shape != res_shape :
678
+ src2 = dpt .broadcast_to (src2 , res_shape )
679
+ ht_binary_ev , binary_ev = self .binary_fn_ (
624
680
src1 = buf1 ,
625
681
src2 = src2 ,
626
682
dst = out ,
627
683
sycl_queue = exec_q ,
628
684
depends = [copy_ev ],
629
685
)
686
+ if not (orig_out is None or orig_out is out ):
687
+ # Copy the out data from temporary buffer to original memory
688
+ ht_copy_out_ev , _ = ti ._copy_usm_ndarray_into_usm_ndarray (
689
+ src = out ,
690
+ dst = orig_out ,
691
+ sycl_queue = exec_q ,
692
+ depends = [binary_ev ],
693
+ )
694
+ ht_copy_out_ev .wait ()
695
+ out = orig_out
630
696
ht_copy_ev .wait ()
631
- ht_ .wait ()
697
+ ht_binary_ev .wait ()
632
698
return out
633
699
634
700
if order in ["K" , "A" ]:
@@ -665,11 +731,6 @@ def __call__(self, o1, o2, out=None, order="K"):
665
731
sycl_queue = exec_q ,
666
732
order = order ,
667
733
)
668
- else :
669
- if res_dt != out .dtype :
670
- raise TypeError (
671
- f"Output array of type { res_dt } is needed, got { out .dtype } "
672
- )
673
734
674
735
buf1 = dpt .broadcast_to (buf1 , res_shape )
675
736
buf2 = dpt .broadcast_to (buf2 , res_shape )
@@ -682,116 +743,3 @@ def __call__(self, o1, o2, out=None, order="K"):
682
743
)
683
744
dpctl .SyclEvent .wait_for ([ht_copy1_ev , ht_copy2_ev , ht_ ])
684
745
return out
685
-
686
- def _inplace (self , lhs , val ):
687
- if self .binary_inplace_fn_ is None :
688
- raise ValueError (
689
- f"In-place operation not supported for ufunc '{ self .name_ } '"
690
- )
691
- if not isinstance (lhs , dpt .usm_ndarray ):
692
- raise TypeError (
693
- f"Expected dpctl.tensor.usm_ndarray, got { type (lhs )} "
694
- )
695
- q1 , lhs_usm_type = _get_queue_usm_type (lhs )
696
- q2 , val_usm_type = _get_queue_usm_type (val )
697
- if q2 is None :
698
- exec_q = q1
699
- usm_type = lhs_usm_type
700
- else :
701
- exec_q = dpctl .utils .get_execution_queue ((q1 , q2 ))
702
- if exec_q is None :
703
- raise ExecutionPlacementError (
704
- "Execution placement can not be unambiguously inferred "
705
- "from input arguments."
706
- )
707
- usm_type = dpctl .utils .get_coerced_usm_type (
708
- (
709
- lhs_usm_type ,
710
- val_usm_type ,
711
- )
712
- )
713
- dpctl .utils .validate_usm_type (usm_type , allow_none = False )
714
- lhs_shape = _get_shape (lhs )
715
- val_shape = _get_shape (val )
716
- if not all (
717
- isinstance (s , (tuple , list ))
718
- for s in (
719
- lhs_shape ,
720
- val_shape ,
721
- )
722
- ):
723
- raise TypeError (
724
- "Shape of arguments can not be inferred. "
725
- "Arguments are expected to be "
726
- "lists, tuples, or both"
727
- )
728
- try :
729
- res_shape = _broadcast_shape_impl (
730
- [
731
- lhs_shape ,
732
- val_shape ,
733
- ]
734
- )
735
- except ValueError :
736
- raise ValueError (
737
- "operands could not be broadcast together with shapes "
738
- f"{ lhs_shape } and { val_shape } "
739
- )
740
- if res_shape != lhs_shape :
741
- raise ValueError (
742
- f"output shape { lhs_shape } does not match "
743
- f"broadcast shape { res_shape } "
744
- )
745
- sycl_dev = exec_q .sycl_device
746
- lhs_dtype = lhs .dtype
747
- val_dtype = _get_dtype (val , sycl_dev )
748
- if not _validate_dtype (val_dtype ):
749
- raise ValueError ("Input operand of unsupported type" )
750
-
751
- lhs_dtype , val_dtype = _resolve_weak_types (
752
- lhs_dtype , val_dtype , sycl_dev
753
- )
754
-
755
- buf_dt = _find_inplace_dtype (
756
- lhs_dtype , val_dtype , self .result_type_resolver_fn_ , sycl_dev
757
- )
758
-
759
- if buf_dt is None :
760
- raise TypeError (
761
- f"In-place '{ self .name_ } ' does not support input types "
762
- f"({ lhs_dtype } , { val_dtype } ), "
763
- "and the inputs could not be safely coerced to any "
764
- "supported types according to the casting rule ''safe''."
765
- )
766
-
767
- if isinstance (val , dpt .usm_ndarray ):
768
- rhs = val
769
- overlap = ti ._array_overlap (lhs , rhs )
770
- else :
771
- rhs = dpt .asarray (val , dtype = val_dtype , sycl_queue = exec_q )
772
- overlap = False
773
-
774
- if buf_dt == val_dtype and overlap is False :
775
- rhs = dpt .broadcast_to (rhs , res_shape )
776
- ht_ , _ = self .binary_inplace_fn_ (
777
- lhs = lhs , rhs = rhs , sycl_queue = exec_q
778
- )
779
- ht_ .wait ()
780
-
781
- else :
782
- buf = dpt .empty_like (rhs , dtype = buf_dt )
783
- ht_copy_ev , copy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
784
- src = rhs , dst = buf , sycl_queue = exec_q
785
- )
786
-
787
- buf = dpt .broadcast_to (buf , res_shape )
788
- ht_ , _ = self .binary_inplace_fn_ (
789
- lhs = lhs ,
790
- rhs = buf ,
791
- sycl_queue = exec_q ,
792
- depends = [copy_ev ],
793
- )
794
- ht_copy_ev .wait ()
795
- ht_ .wait ()
796
-
797
- return lhs
0 commit comments