Skip to content

Commit cf4660d

Browse files
Merge pull request #1368 from IntelPython/scalar_special_methods_required_zero_dim_ndarray
Address NumPy 1.25 deprecation warnings
2 parents 1595dce + a653eb3 commit cf4660d

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

dpctl/tensor/_usmarray.pyx

+13-12
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ cdef class InternalUSMArrayError(Exception):
5252
pass
5353

5454

55+
cdef object _as_zero_dim_ndarray(object usm_ary):
56+
"Convert size-1 array to NumPy 0d array"
57+
mem_view = dpmem.as_usm_memory(usm_ary)
58+
host_buf = mem_view.copy_to_host()
59+
view = host_buf.view(usm_ary.dtype)
60+
view.shape = tuple()
61+
return view
62+
63+
5564
cdef class usm_ndarray:
5665
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
5766
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -840,9 +849,7 @@ cdef class usm_ndarray:
840849

841850
def __bool__(self):
842851
if self.size == 1:
843-
mem_view = dpmem.as_usm_memory(self)
844-
host_buf = mem_view.copy_to_host()
845-
view = host_buf.view(self.dtype)
852+
view = _as_zero_dim_ndarray(self)
846853
return view.__bool__()
847854

848855
if self.size == 0:
@@ -857,9 +864,7 @@ cdef class usm_ndarray:
857864

858865
def __float__(self):
859866
if self.size == 1:
860-
mem_view = dpmem.as_usm_memory(self)
861-
host_buf = mem_view.copy_to_host()
862-
view = host_buf.view(self.dtype)
867+
view = _as_zero_dim_ndarray(self)
863868
return view.__float__()
864869

865870
raise ValueError(
@@ -868,9 +873,7 @@ cdef class usm_ndarray:
868873

869874
def __complex__(self):
870875
if self.size == 1:
871-
mem_view = dpmem.as_usm_memory(self)
872-
host_buf = mem_view.copy_to_host()
873-
view = host_buf.view(self.dtype)
876+
view = _as_zero_dim_ndarray(self)
874877
return view.__complex__()
875878

876879
raise ValueError(
@@ -879,9 +882,7 @@ cdef class usm_ndarray:
879882

880883
def __int__(self):
881884
if self.size == 1:
882-
mem_view = dpmem.as_usm_memory(self)
883-
host_buf = mem_view.copy_to_host()
884-
view = host_buf.view(self.dtype)
885+
view = _as_zero_dim_ndarray(self)
885886
return view.__int__()
886887

887888
raise ValueError(

dpctl/tests/test_usm_ndarray_ctor.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ def test_copy_scalar_with_func(func, shape, dtype):
239239
X = dpt.usm_ndarray(shape, dtype=dtype)
240240
except dpctl.SyclDeviceCreationError:
241241
pytest.skip("No SYCL devices available")
242-
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
243-
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
242+
Y = np.arange(1, X.size + 1, dtype=dtype)
243+
X.usm_data.copy_from_host(Y.view("|u1"))
244+
Y.shape = tuple()
244245
assert func(X) == func(Y)
245246

246247

@@ -254,8 +255,9 @@ def test_copy_scalar_with_method(method, shape, dtype):
254255
X = dpt.usm_ndarray(shape, dtype=dtype)
255256
except dpctl.SyclDeviceCreationError:
256257
pytest.skip("No SYCL devices available")
257-
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
258-
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
258+
Y = np.arange(1, X.size + 1, dtype=dtype)
259+
X.usm_data.copy_from_host(Y.view("|u1"))
260+
Y.shape = tuple()
259261
assert getattr(X, method)() == getattr(Y, method)()
260262

261263

0 commit comments

Comments
 (0)