29
29
from ._copy_utils import _empty_like_orderK , _empty_like_pair_orderK
30
30
from ._type_utils import (
31
31
_acceptance_fn_default ,
32
+ _all_data_types ,
32
33
_find_buf_dtype ,
33
34
_find_buf_dtype2 ,
34
35
_to_device_supported_dtype ,
@@ -44,6 +45,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
44
45
self .__name__ = "UnaryElementwiseFunc"
45
46
self .name_ = name
46
47
self .result_type_resolver_fn_ = result_type_resolver_fn
48
+ self .types_ = None
47
49
self .unary_fn_ = unary_dp_impl_fn
48
50
self .__doc__ = docs
49
51
@@ -53,6 +55,18 @@ def __str__(self):
53
55
def __repr__ (self ):
54
56
return f"<{ self .__name__ } '{ self .name_ } '>"
55
57
58
+ @property
59
+ def types (self ):
60
+ types = self .types_
61
+ if not types :
62
+ types = []
63
+ for dt1 in _all_data_types (True , True ):
64
+ dt2 = self .result_type_resolver_fn_ (dt1 )
65
+ if dt2 :
66
+ types .append (f"{ dt1 .char } ->{ dt2 .char } " )
67
+ self .types_ = types
68
+ return types
69
+
56
70
def __call__ (self , x , out = None , order = "K" ):
57
71
if not isinstance (x , dpt .usm_ndarray ):
58
72
raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
@@ -363,6 +377,7 @@ def __init__(
363
377
self .__name__ = "BinaryElementwiseFunc"
364
378
self .name_ = name
365
379
self .result_type_resolver_fn_ = result_type_resolver_fn
380
+ self .types_ = None
366
381
self .binary_fn_ = binary_dp_impl_fn
367
382
self .binary_inplace_fn_ = binary_inplace_fn
368
383
self .__doc__ = docs
@@ -377,6 +392,20 @@ def __str__(self):
377
392
def __repr__ (self ):
378
393
return f"<{ self .__name__ } '{ self .name_ } '>"
379
394
395
+ @property
396
+ def types (self ):
397
+ types = self .types_
398
+ if not types :
399
+ types = []
400
+ _all_dtypes = _all_data_types (True , True )
401
+ for dt1 in _all_dtypes :
402
+ for dt2 in _all_dtypes :
403
+ dt3 = self .result_type_resolver_fn_ (dt1 , dt2 )
404
+ if dt3 :
405
+ types .append (f"{ dt1 .char } { dt2 .char } ->{ dt3 .char } " )
406
+ self .types_ = types
407
+ return types
408
+
380
409
def __call__ (self , o1 , o2 , out = None , order = "K" ):
381
410
if order not in ["K" , "C" , "F" , "A" ]:
382
411
order = "K"
0 commit comments