Skip to content

Commit 4cc552f

Browse files
authored
Implements types property for elementwise functions (#1361)
* Implements ``types`` property for elementwise functions - Output corresponds with Numpy's: a list with an arrow marking the domain to range type map * Added tests for behavior of types property
1 parent abc8c80 commit 4cc552f

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
3030
from ._type_utils import (
3131
_acceptance_fn_default,
32+
_all_data_types,
3233
_find_buf_dtype,
3334
_find_buf_dtype2,
3435
_to_device_supported_dtype,
@@ -44,6 +45,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4445
self.__name__ = "UnaryElementwiseFunc"
4546
self.name_ = name
4647
self.result_type_resolver_fn_ = result_type_resolver_fn
48+
self.types_ = None
4749
self.unary_fn_ = unary_dp_impl_fn
4850
self.__doc__ = docs
4951

@@ -53,6 +55,18 @@ def __str__(self):
5355
def __repr__(self):
5456
return f"<{self.__name__} '{self.name_}'>"
5557

58+
@property
59+
def types(self):
60+
types = self.types_
61+
if not types:
62+
types = []
63+
for dt1 in _all_data_types(True, True):
64+
dt2 = self.result_type_resolver_fn_(dt1)
65+
if dt2:
66+
types.append(f"{dt1.char}->{dt2.char}")
67+
self.types_ = types
68+
return types
69+
5670
def __call__(self, x, out=None, order="K"):
5771
if not isinstance(x, dpt.usm_ndarray):
5872
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -363,6 +377,7 @@ def __init__(
363377
self.__name__ = "BinaryElementwiseFunc"
364378
self.name_ = name
365379
self.result_type_resolver_fn_ = result_type_resolver_fn
380+
self.types_ = None
366381
self.binary_fn_ = binary_dp_impl_fn
367382
self.binary_inplace_fn_ = binary_inplace_fn
368383
self.__doc__ = docs
@@ -377,6 +392,20 @@ def __str__(self):
377392
def __repr__(self):
378393
return f"<{self.__name__} '{self.name_}'>"
379394

395+
@property
396+
def types(self):
397+
types = self.types_
398+
if not types:
399+
types = []
400+
_all_dtypes = _all_data_types(True, True)
401+
for dt1 in _all_dtypes:
402+
for dt2 in _all_dtypes:
403+
dt3 = self.result_type_resolver_fn_(dt1, dt2)
404+
if dt3:
405+
types.append(f"{dt1.char}{dt2.char}->{dt3.char}")
406+
self.types_ = types
407+
return types
408+
380409
def __call__(self, o1, o2, out=None, order="K"):
381410
if order not in ["K", "C", "F", "A"]:
382411
order = "K"

dpctl/tests/elementwise/test_abs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,15 @@ def test_abs_usm_type(usm_type):
7676
assert np.allclose(dpt.asnumpy(Y), expected_Y)
7777

7878

79+
def test_abs_types_prop():
80+
types = dpt.abs.types_
81+
assert types is None
82+
types = dpt.abs.types
83+
assert isinstance(types, list)
84+
assert len(types) > 0
85+
assert types == dpt.abs.types_
86+
87+
7988
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
8089
def test_abs_order(dtype):
8190
q = get_queue_or_skip()

dpctl/tests/elementwise/test_add.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,15 @@ def __sycl_usm_array_interface__(self):
258258
dpt.add(a, c)
259259

260260

261+
def test_add_types_property():
262+
types = dpt.add.types_
263+
assert types is None
264+
types = dpt.add.types
265+
assert isinstance(types, list)
266+
assert len(types) > 0
267+
assert types == dpt.add.types_
268+
269+
261270
def test_add_errors():
262271
get_queue_or_skip()
263272
try:

0 commit comments

Comments
 (0)