Skip to content

Commit 29fd9e5

Browse files
Address NumPy 1.25 deprecation warnings
Ensure that ndarray that we converted usm_ndarray single element instance into is 0d before calling __int__, __float__, __complex__, __index__.
1 parent 8eab04b commit 29fd9e5

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ cdef class InternalUSMArrayError(Exception):
5252
pass
5353

5454

55+
cdef object _as_zero_dim_ndarray(object usm_ary):
56+
"Convert size-1 array to NumPy 0d array"
57+
mem_view = dpmem.as_usm_memory(usm_ary)
58+
host_buf = mem_view.copy_to_host()
59+
view = host_buf.view(usm_ary.dtype)
60+
view.shape = tuple()
61+
return view
62+
63+
5564
cdef class usm_ndarray:
5665
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
5766
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -840,9 +849,7 @@ cdef class usm_ndarray:
840849

841850
def __bool__(self):
842851
if self.size == 1:
843-
mem_view = dpmem.as_usm_memory(self)
844-
host_buf = mem_view.copy_to_host()
845-
view = host_buf.view(self.dtype)
852+
view = _as_zero_dim_ndarray(self)
846853
return view.__bool__()
847854

848855
if self.size == 0:
@@ -857,9 +864,7 @@ cdef class usm_ndarray:
857864

858865
def __float__(self):
859866
if self.size == 1:
860-
mem_view = dpmem.as_usm_memory(self)
861-
host_buf = mem_view.copy_to_host()
862-
view = host_buf.view(self.dtype)
867+
view = _as_zero_dim_ndarray(self)
863868
return view.__float__()
864869

865870
raise ValueError(
@@ -868,9 +873,7 @@ cdef class usm_ndarray:
868873

869874
def __complex__(self):
870875
if self.size == 1:
871-
mem_view = dpmem.as_usm_memory(self)
872-
host_buf = mem_view.copy_to_host()
873-
view = host_buf.view(self.dtype)
876+
view = _as_zero_dim_ndarray(self)
874877
return view.__complex__()
875878

876879
raise ValueError(
@@ -879,9 +882,7 @@ cdef class usm_ndarray:
879882

880883
def __int__(self):
881884
if self.size == 1:
882-
mem_view = dpmem.as_usm_memory(self)
883-
host_buf = mem_view.copy_to_host()
884-
view = host_buf.view(self.dtype)
885+
view = _as_zero_dim_ndarray(self)
885886
return view.__int__()
886887

887888
raise ValueError(

0 commit comments

Comments
 (0)