diff --git a/dpctl/tensor/_copy_utils.py b/dpctl/tensor/_copy_utils.py index 3eae29f057..ad5b956851 100644 --- a/dpctl/tensor/_copy_utils.py +++ b/dpctl/tensor/_copy_utils.py @@ -517,6 +517,11 @@ def copy(usm_ary, order="K"): - "K": match the layout of `usm_ary` as closely as possible. """ + if len(order) == 0 or order[0] not in "KkAaCcFf": + raise ValueError( + "Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'." + ) + order = order[0].upper() if not isinstance(usm_ary, dpt.usm_ndarray): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" @@ -585,11 +590,11 @@ def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True): return TypeError( f"Expected object of type dpt.usm_ndarray, got {type(usm_ary)}" ) - if not isinstance(order, str) or order not in ["A", "C", "F", "K"]: + if len(order) == 0 or order[0] not in "KkAaCcFf": raise ValueError( - "Unrecognized value of the order keyword. " - "Recognized values are 'A', 'C', 'F', or 'K'" + "Unrecognized order keyword value, expecting 'K', 'A', 'F', or 'C'." ) + order = order[0].upper() ary_dtype = usm_ary.dtype target_dtype = _get_dtype(newdtype, usm_ary.sycl_queue) if not dpt.can_cast(ary_dtype, target_dtype, casting=casting):