
Commit aa0f963

mathause authored and fujiisoup committed
Feature/align in dot (#3699)
* add tests
* implement align
* whats new
* fix changes to whats new
* review: fix typos
1 parent 5c97641 commit aa0f963

File tree

doc/whats-new.rst
xarray/core/computation.py
xarray/tests/test_computation.py
xarray/tests/test_dataarray.py

4 files changed (+120, -1 lines)

doc/whats-new.rst

Lines changed: 4 additions & 1 deletion
@@ -21,12 +21,15 @@ v0.15.0 (unreleased)
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
-
 - Remove ``compat`` and ``encoding`` kwargs from ``DataArray``, which
   have been deprecated since 0.12. (:pull:`3650`).
   Instead, specify the encoding when writing to disk or set
   the ``encoding`` attribute directly.
   By `Maximilian Roos <https://github.com/max-sixty>`_
+- :py:func:`xarray.dot`, :py:meth:`DataArray.dot`, and the ``@`` operator now
+  use ``align="inner"`` (except when ``xarray.set_options(arithmetic_join="exact")``;
+  :issue:`3694`) by `Mathias Hauser <https://github.com/mathause>`_.
+
 
 New Features
 ~~~~~~~~~~~~
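As an aside (not part of the commit), a minimal sketch of the behavior described in the entry above; the array names and coordinate values are made up, loosely following the tests added further down in this commit:

import numpy as np
import xarray as xr

# two 1-d arrays whose "x" coordinates only partially overlap (x = 2..5)
da_a = xr.DataArray(np.arange(6.0), coords=[np.arange(6)], dims=["x"])
da_b = xr.DataArray(np.arange(6.0), coords=[np.arange(2, 8)], dims=["x"])

# dot and @ now align with join="inner", i.e. they sum over the overlapping labels
print(da_a.dot(da_b))  # same value as (da_a * da_b).sum()
print(da_a @ da_b)     # the @ operator follows the same rule

# with arithmetic_join="exact", the mismatched indexes raise instead
with xr.set_options(arithmetic_join="exact"):
    try:
        da_a @ da_b
    except ValueError as err:
        print(err)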

xarray/core/computation.py

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@
 from . import duck_array_ops, utils
 from .alignment import deep_align
 from .merge import merge_coordinates_without_align
+from .options import OPTIONS
 from .pycompat import dask_array_type
 from .utils import is_dict_like
 from .variable import Variable

@@ -1175,6 +1176,11 @@ def dot(*arrays, dims=None, **kwargs):
     subscripts = ",".join(subscripts_list)
     subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]])
 
+    join = OPTIONS["arithmetic_join"]
+    # using "inner" emulates `(a * b).sum()` for all joins (except "exact")
+    if join != "exact":
+        join = "inner"
+
     # subscripts should be passed to np.einsum as arg, not as kwargs. We need
     # to construct a partial function for apply_ufunc to work.
     func = functools.partial(duck_array_ops.einsum, subscripts, **kwargs)

@@ -1183,6 +1189,7 @@ def dot(*arrays, dims=None, **kwargs):
         *arrays,
         input_core_dims=input_core_dims,
         output_core_dims=output_core_dims,
+        join=join,
         dask="allowed",
     )
     return result.transpose(*[d for d in all_dims if d in result.dims])
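The comment in the hunk above relies on the fact that dot is equivalent to multiplying the arguments and summing over their shared dimensions: under "left", "right", or "outer" alignment the non-overlapping labels only contribute NaN to the product, and .sum() skips NaN, so the result equals the "inner" one. A small sketch of that equivalence (not part of the commit; the arrays and coordinates are made up):

import numpy as np
import xarray as xr

a = xr.DataArray(
    np.arange(24.0).reshape(6, 4),
    coords={"a": np.arange(6), "b": np.arange(4)},
    dims=["a", "b"],
)
b = xr.DataArray(
    np.arange(120.0).reshape(6, 4, 5),
    coords={"a": np.arange(2, 8), "b": np.arange(4)},  # "a" only partially overlaps
    dims=["a", "b", "c"],
)

with xr.set_options(arithmetic_join="outer"):
    # the outer-joined product is NaN wherever the "a" labels do not overlap,
    # and .sum() skips NaN, so this matches the inner-join result used by xr.dot
    expected = (a * b).sum(["a", "b"])
    actual = xr.dot(a, b)
    xr.testing.assert_allclose(expected, actual)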

xarray/tests/test_computation.py

Lines changed: 54 additions & 0 deletions
@@ -1043,6 +1043,60 @@ def test_dot(use_dask):
     pickle.loads(pickle.dumps(xr.dot(da_a)))
 
 
+@pytest.mark.parametrize("use_dask", [True, False])
+def test_dot_align_coords(use_dask):
+    # GH 3694
+
+    if use_dask:
+        if not has_dask:
+            pytest.skip("test for dask.")
+
+    a = np.arange(30 * 4).reshape(30, 4)
+    b = np.arange(30 * 4 * 5).reshape(30, 4, 5)
+
+    # use partially overlapping coords
+    coords_a = {"a": np.arange(30), "b": np.arange(4)}
+    coords_b = {"a": np.arange(5, 35), "b": np.arange(1, 5)}
+
+    da_a = xr.DataArray(a, dims=["a", "b"], coords=coords_a)
+    da_b = xr.DataArray(b, dims=["a", "b", "c"], coords=coords_b)
+
+    if use_dask:
+        da_a = da_a.chunk({"a": 3})
+        da_b = da_b.chunk({"a": 3})
+
+    # join="inner" is the default
+    actual = xr.dot(da_a, da_b)
+    # `dot` sums over the common dimensions of the arguments
+    expected = (da_a * da_b).sum(["a", "b"])
+    xr.testing.assert_allclose(expected, actual)
+
+    actual = xr.dot(da_a, da_b, dims=...)
+    expected = (da_a * da_b).sum()
+    xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="exact"):
+        with raises_regex(ValueError, "indexes along dimension"):
+            xr.dot(da_a, da_b)
+
+    # NOTE: dot always uses `join="inner"` because `(a * b).sum()` yields the same for all
+    # join method (except "exact")
+    with xr.set_options(arithmetic_join="left"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="right"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+    with xr.set_options(arithmetic_join="outer"):
+        actual = xr.dot(da_a, da_b)
+        expected = (da_a * da_b).sum(["a", "b"])
+        xr.testing.assert_allclose(expected, actual)
+
+
 def test_where():
     cond = xr.DataArray([True, False], dims="x")
     actual = xr.where(cond, 1, 0)

xarray/tests/test_dataarray.py

Lines changed: 55 additions & 0 deletions
@@ -3973,6 +3973,43 @@ def test_dot(self):
         with pytest.raises(TypeError):
             da.dot(dm.values)
 
+    def test_dot_align_coords(self):
+        # GH 3694
+
+        x = np.linspace(-3, 3, 6)
+        y = np.linspace(-3, 3, 5)
+        z_a = range(4)
+        da_vals = np.arange(6 * 5 * 4).reshape((6, 5, 4))
+        da = DataArray(da_vals, coords=[x, y, z_a], dims=["x", "y", "z"])
+
+        z_m = range(2, 6)
+        dm_vals = range(4)
+        dm = DataArray(dm_vals, coords=[z_m], dims=["z"])
+
+        with xr.set_options(arithmetic_join="exact"):
+            with raises_regex(ValueError, "indexes along dimension"):
+                da.dot(dm)
+
+        da_aligned, dm_aligned = xr.align(da, dm, join="inner")
+
+        # nd dot 1d
+        actual = da.dot(dm)
+        expected_vals = np.tensordot(da_aligned.values, dm_aligned.values, [2, 0])
+        expected = DataArray(expected_vals, coords=[x, da_aligned.y], dims=["x", "y"])
+        assert_equal(expected, actual)
+
+        # multiple shared dims
+        dm_vals = np.arange(20 * 5 * 4).reshape((20, 5, 4))
+        j = np.linspace(-3, 3, 20)
+        dm = DataArray(dm_vals, coords=[j, y, z_m], dims=["j", "y", "z"])
+        da_aligned, dm_aligned = xr.align(da, dm, join="inner")
+        actual = da.dot(dm)
+        expected_vals = np.tensordot(
+            da_aligned.values, dm_aligned.values, axes=([1, 2], [1, 2])
+        )
+        expected = DataArray(expected_vals, coords=[x, j], dims=["x", "j"])
+        assert_equal(expected, actual)
+
     def test_matmul(self):
 
         # copied from above (could make a fixture)

@@ -3986,6 +4023,24 @@ def test_matmul(self):
         expected = da.dot(da)
         assert_identical(result, expected)
 
+    def test_matmul_align_coords(self):
+        # GH 3694
+
+        x_a = np.arange(6)
+        x_b = np.arange(2, 8)
+        da_vals = np.arange(6)
+        da_a = DataArray(da_vals, coords=[x_a], dims=["x"])
+        da_b = DataArray(da_vals, coords=[x_b], dims=["x"])
+
+        # only test arithmetic_join="inner" (=default)
+        result = da_a @ da_b
+        expected = da_a.dot(da_b)
+        assert_identical(result, expected)
+
+        with xr.set_options(arithmetic_join="exact"):
+            with raises_regex(ValueError, "indexes along dimension"):
+                da_a @ da_b
+
     def test_binary_op_propagate_indexes(self):
         # regression test for GH2227
         self.dv["x"] = np.arange(self.dv.sizes["x"])
