@@ -52,6 +52,15 @@ cdef class InternalUSMArrayError(Exception):
52
52
pass
53
53
54
54
55
+ cdef object _as_zero_dim_ndarray(object usm_ary):
56
+ " Convert size-1 array to NumPy 0d array"
57
+ mem_view = dpmem.as_usm_memory(usm_ary)
58
+ host_buf = mem_view.copy_to_host()
59
+ view = host_buf.view(usm_ary.dtype)
60
+ view.shape = tuple ()
61
+ return view
62
+
63
+
55
64
cdef class usm_ndarray:
56
65
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
57
66
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -840,9 +849,7 @@ cdef class usm_ndarray:
840
849
841
850
def __bool__ (self ):
842
851
if self .size == 1 :
843
- mem_view = dpmem.as_usm_memory(self )
844
- host_buf = mem_view.copy_to_host()
845
- view = host_buf.view(self .dtype)
852
+ view = _as_zero_dim_ndarray(self )
846
853
return view.__bool__()
847
854
848
855
if self .size == 0 :
@@ -857,9 +864,7 @@ cdef class usm_ndarray:
857
864
858
865
def __float__ (self ):
859
866
if self .size == 1 :
860
- mem_view = dpmem.as_usm_memory(self )
861
- host_buf = mem_view.copy_to_host()
862
- view = host_buf.view(self .dtype)
867
+ view = _as_zero_dim_ndarray(self )
863
868
return view.__float__ ()
864
869
865
870
raise ValueError (
@@ -868,9 +873,7 @@ cdef class usm_ndarray:
868
873
869
874
def __complex__ (self ):
870
875
if self .size == 1 :
871
- mem_view = dpmem.as_usm_memory(self )
872
- host_buf = mem_view.copy_to_host()
873
- view = host_buf.view(self .dtype)
876
+ view = _as_zero_dim_ndarray(self )
874
877
return view.__complex__ ()
875
878
876
879
raise ValueError (
@@ -879,9 +882,7 @@ cdef class usm_ndarray:
879
882
880
883
def __int__ (self ):
881
884
if self .size == 1 :
882
- mem_view = dpmem.as_usm_memory(self )
883
- host_buf = mem_view.copy_to_host()
884
- view = host_buf.view(self .dtype)
885
+ view = _as_zero_dim_ndarray(self )
885
886
return view.__int__ ()
886
887
887
888
raise ValueError (
0 commit comments