Skip to content

Commit abc8c80

Browse files
authored
_real_view and _imag_view now set flags correctly (#1355)
- these properties were setting the flags of the output to the flags of the input, which is incorrect, as the output is almost never contiguous - added tests for this behavior
1 parent 83858f0 commit abc8c80

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

dpctl/tensor/_usmarray.pyx

+5-5
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ cdef class usm_ndarray:
662662
"""
663663
if self.nd_ < 2:
664664
raise ValueError(
665-
"array.mT requires array to have at least 2-dimensons."
665+
"array.mT requires array to have at least 2 dimensions."
666666
)
667667
return _m_transpose(self)
668668

@@ -1216,14 +1216,14 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
12161216
offset_elems = ary.get_offset() * 2
12171217
r = usm_ndarray.__new__(
12181218
usm_ndarray,
1219-
_make_int_tuple(ary.nd_, ary.shape_),
1219+
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
12201220
dtype=_make_typestr(r_typenum_),
12211221
strides=tuple(2 * si for si in ary.strides),
12221222
buffer=ary.base_,
12231223
offset=offset_elems,
12241224
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12251225
)
1226-
r.flags_ = ary.flags_
1226+
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
12271227
r.array_namespace_ = ary.array_namespace_
12281228
return r
12291229

@@ -1248,14 +1248,14 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
12481248
offset_elems = 2 * ary.get_offset() + 1
12491249
r = usm_ndarray.__new__(
12501250
usm_ndarray,
1251-
_make_int_tuple(ary.nd_, ary.shape_),
1251+
_make_int_tuple(ary.nd_, ary.shape_) if ary.nd_ > 0 else tuple(),
12521252
dtype=_make_typestr(r_typenum_),
12531253
strides=tuple(2 * si for si in ary.strides),
12541254
buffer=ary.base_,
12551255
offset=offset_elems,
12561256
order=('C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else 'F')
12571257
)
1258-
r.flags_ = ary.flags_
1258+
r.flags_ |= (ary.flags_ & USM_ARRAY_WRITABLE)
12591259
r.array_namespace_ = ary.array_namespace_
12601260
return r
12611261

dpctl/tests/test_usm_ndarray_ctor.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1438,17 +1438,26 @@ def test_real_imag_views():
14381438
n, m = 2, 3
14391439
try:
14401440
X = dpt.usm_ndarray((n, m), "c8")
1441+
X_scalar = dpt.usm_ndarray((), dtype="c8")
14411442
except dpctl.SyclDeviceCreationError:
14421443
pytest.skip("No SYCL devices available")
14431444
Xnp_r = np.arange(n * m, dtype="f4").reshape((n, m))
14441445
Xnp_i = np.arange(n * m, 2 * n * m, dtype="f4").reshape((n, m))
14451446
Xnp = Xnp_r + 1j * Xnp_i
14461447
X[:] = Xnp
1447-
assert np.array_equal(dpt.to_numpy(X.real), Xnp.real)
1448+
X_real = X.real
1449+
X_imag = X.imag
1450+
assert np.array_equal(dpt.to_numpy(X_real), Xnp.real)
14481451
assert np.array_equal(dpt.to_numpy(X.imag), Xnp.imag)
1452+
assert not X_real.flags["C"] and not X_real.flags["F"]
1453+
assert not X_imag.flags["C"] and not X_imag.flags["F"]
1454+
assert X_real.strides == X_imag.strides
14491455
assert np.array_equal(dpt.to_numpy(X[1:].real), Xnp[1:].real)
14501456
assert np.array_equal(dpt.to_numpy(X[1:].imag), Xnp[1:].imag)
14511457

1458+
X_scalar[...] = complex(n * m, 2 * n * m)
1459+
assert X_scalar.real and X_scalar.imag
1460+
14521461

14531462
@pytest.mark.parametrize(
14541463
"dtype",

0 commit comments

Comments
 (0)