Skip to content

Commit 46ab60d

Browse files
committed
add pint to the list of supported types in SupportsArithmetic
1 parent c91c2a9 commit 46ab60d

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

xarray/core/arithmetic.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from .options import OPTIONS
7-
from .pycompat import dask_array_type
7+
from .pycompat import dask_array_type, pint_array_type
88
from .utils import not_implemented
99

1010

@@ -22,22 +22,18 @@ class SupportsArithmetic:
2222

2323
# TODO: allow extending this with some sort of registration system
2424
_HANDLED_TYPES = (
25-
np.ndarray,
26-
np.generic,
27-
numbers.Number,
28-
bytes,
29-
str,
30-
) + dask_array_type
25+
(np.ndarray, np.generic, numbers.Number, bytes, str,)
26+
+ dask_array_type
27+
+ pint_array_type
28+
)
3129

3230
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
3331
from .computation import apply_ufunc
3432

3533
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
3634
out = kwargs.get("out", ())
3735
for x in inputs + out:
38-
if not isinstance(
39-
x, self._HANDLED_TYPES + (SupportsArithmetic,)
40-
) and not hasattr(x, "__array_ufunc__"):
36+
if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
4137
return NotImplemented
4238

4339
if ufunc.signature is not None:

xarray/core/pycompat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,11 @@
1717
sparse_array_type = (sparse.SparseArray,)
1818
except ImportError: # pragma: no cover
1919
sparse_array_type = ()
20+
21+
try:
22+
# solely for isinstance checks
23+
import pint
24+
25+
pint_array_type = (pint.Quantity,)
26+
except ImportError: # pragma: no cover
27+
pint_array_type = ()

0 commit comments

Comments
 (0)