Skip to content

Commit 95bca62

Browse files
chunweiyuanshoyer
authored andcommitted
Options to binary ops kwargs (#1065)
* Added join key to OPTIONS, used in dataarray & dataset binary ops, with a test module in test_dataarray.py * Added binary_ops test to test_dataset. * Changed variable names according to review comments. * Changed default key to arithmetic_join, and shortened tests. Also added to computation.rst and whats-new.rst * Emphasis on context manager for xr.set_options() use. * Changed to actual vs expected testing nomenclature. * Applies join options to Dataset.data_vars as well. * Preserve order of joined data_vars, left-to-right. Use np.nan as default filler. * PEP8 and doctring.
1 parent 1ff3a0b commit 95bca62

File tree

7 files changed

+95
-20
lines changed

7 files changed

+95
-20
lines changed

doc/computation.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ coordinates with the same name as a dimension, marked by ``*``) on objects used
210210
in binary operations.
211211

212212
Similarly to pandas, this alignment is automatic for arithmetic on binary
213-
operations. Note that unlike pandas, this the result of a binary operation is
214-
by the *intersection* (not the union) of coordinate labels:
213+
operations. The default result of a binary operation is by the *intersection*
214+
(not the union) of coordinate labels:
215215

216216
.. ipython:: python
217217
@@ -225,6 +225,15 @@ If the result would be empty, an error is raised instead:
225225
In [1]: arr[:2] + arr[2:]
226226
ValueError: no overlapping labels for some dimensions: ['x']
227227

228+
However, one can explicitly change this default automatic alignment type ("inner")
229+
via :py:func:`~xarray.set_options()` in context manager:
230+
231+
.. ipython:: python
232+
233+
with xr.set_options(arithmetic_join="outer"):
234+
arr + arr[:1]
235+
arr + arr[:1]
236+
228237
Before loops or performance critical code, it's a good idea to align arrays
229238
explicitly (e.g., by putting them in the same Dataset or using
230239
:py:func:`~xarray.align`) to avoid the overhead of repeated alignment with each

doc/whats-new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ Deprecations
4444

4545
Enhancements
4646
~~~~~~~~~~~~
47+
- Added the ability to change default automatic alignment (arithmetic_join="inner")
48+
for binary operations via :py:func:`~xarray.set_options()`
49+
(see :ref:`automatic alignment`).
50+
By `Chun-Wei Yuan <https://github.com/chunweiyuan>`_.
51+
4752
- Add checking of ``attr`` names and values when saving to netCDF, raising useful
4853
error messages if they are invalid. (:issue:`911`).
4954
By `Robin Wilson <https://github.com/robintw>`_.

xarray/core/dataarray.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
assert_unique_multiindex_level_names)
2626
from .formatting import format_item
2727
from .utils import decode_numpy_dict_values, ensure_us_time_resolution
28+
from .options import OPTIONS
2829

2930

3031
def _infer_coords_and_dims(shape, coords, dims):
@@ -1357,13 +1358,14 @@ def func(self, *args, **kwargs):
13571358
return func
13581359

13591360
@staticmethod
1360-
def _binary_op(f, reflexive=False, join='inner', **ignored_kwargs):
1361+
def _binary_op(f, reflexive=False, join=None, **ignored_kwargs):
13611362
@functools.wraps(f)
13621363
def func(self, other):
13631364
if isinstance(other, (Dataset, groupby.GroupBy)):
13641365
return NotImplemented
13651366
if hasattr(other, 'indexes'):
1366-
self, other = align(self, other, join=join, copy=False)
1367+
align_type = OPTIONS['arithmetic_join'] if join is None else join
1368+
self, other = align(self, other, join=align_type, copy=False)
13671369
other_variable = getattr(other, 'variable', other)
13681370
other_coords = getattr(other, 'coords', None)
13691371

xarray/core/dataset.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .pycompat import (iteritems, basestring, OrderedDict,
2929
dask_array_type)
3030
from .combine import concat
31+
from .options import OPTIONS
3132

3233

3334
# list of attributes of pd.DatetimeIndex that are ndarrays of time info
@@ -2012,15 +2013,17 @@ def func(self, *args, **kwargs):
20122013
return func
20132014

20142015
@staticmethod
2015-
def _binary_op(f, reflexive=False, join='inner', fillna=False):
2016+
def _binary_op(f, reflexive=False, join=None, fillna=False):
20162017
@functools.wraps(f)
20172018
def func(self, other):
20182019
if isinstance(other, groupby.GroupBy):
20192020
return NotImplemented
2021+
align_type = OPTIONS['arithmetic_join'] if join is None else join
20202022
if hasattr(other, 'indexes'):
2021-
self, other = align(self, other, join=join, copy=False)
2023+
self, other = align(self, other, join=align_type, copy=False)
20222024
g = f if not reflexive else lambda x, y: f(y, x)
2023-
ds = self._calculate_binary_op(g, other, fillna=fillna)
2025+
ds = self._calculate_binary_op(g, other, join=align_type,
2026+
fillna=fillna)
20242027
return ds
20252028
return func
20262029

@@ -2042,25 +2045,32 @@ def func(self, other):
20422045
return self
20432046
return func
20442047

2045-
def _calculate_binary_op(self, f, other, inplace=False, fillna=False):
2048+
def _calculate_binary_op(self, f, other, join='inner',
2049+
inplace=False, fillna=False):
20462050

20472051
def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
2052+
if fillna and join != 'left':
2053+
raise ValueError('`fillna` must be accompanied by left join')
20482054
if fillna and not set(rhs_data_vars) <= set(lhs_data_vars):
20492055
raise ValueError('all variables in the argument to `fillna` '
20502056
'must be contained in the original dataset')
2057+
if inplace and set(lhs_data_vars) != set(rhs_data_vars):
2058+
raise ValueError('datasets must have the same data variables '
2059+
'for in-place arithmetic operations: %s, %s'
2060+
% (list(lhs_data_vars), list(rhs_data_vars)))
20512061

20522062
dest_vars = OrderedDict()
2063+
20532064
for k in lhs_data_vars:
20542065
if k in rhs_data_vars:
20552066
dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
2056-
elif inplace:
2057-
raise ValueError(
2058-
'datasets must have the same data variables '
2059-
'for in-place arithmetic operations: %s, %s'
2060-
% (list(lhs_data_vars), list(rhs_data_vars)))
2061-
elif fillna:
2062-
# this shortcuts left alignment of variables for fillna
2063-
dest_vars[k] = lhs_vars[k]
2067+
elif join in ["left", "outer"]:
2068+
dest_vars[k] = (lhs_vars[k] if fillna else
2069+
f(lhs_vars[k], np.nan))
2070+
for k in rhs_data_vars:
2071+
if k not in dest_vars and join in ["right", "outer"]:
2072+
dest_vars[k] = (rhs_vars[k] if fillna else
2073+
f(rhs_vars[k], np.nan))
20642074
return dest_vars
20652075

20662076
if utils.is_dict_like(other) and not isinstance(other, Dataset):
@@ -2080,7 +2090,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
20802090
other_variable = getattr(other, 'variable', other)
20812091
new_vars = OrderedDict((k, f(self.variables[k], other_variable))
20822092
for k in self.data_vars)
2083-
20842093
ds._variables.update(new_vars)
20852094
return ds
20862095

xarray/core/options.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from __future__ import absolute_import
22
from __future__ import division
33
from __future__ import print_function
4-
OPTIONS = {'display_width': 80}
4+
OPTIONS = {'display_width': 80,
5+
'arithmetic_join': "inner"}
56

67

78
class set_options(object):
89
"""Set options for xarray in a controlled context.
910
10-
Currently, the only supported option is ``display_width``, which has a
11-
default value of 80.
11+
Currently, the only supported options are:
12+
1.) display_width: maximum terminal display width of data arrays.
13+
Default=80.
14+
2.) arithmetic_join: dataarray/dataset alignment in binary operations.
15+
Default='inner'.
1216
1317
You can use ``set_options`` either as a context manager:
1418

xarray/test/test_dataarray.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2278,3 +2278,17 @@ def test_dot(self):
22782278
da.dot(dm.values)
22792279
with self.assertRaisesRegexp(ValueError, 'no shared dimensions'):
22802280
da.dot(DataArray(1))
2281+
2282+
def test_binary_op_join_setting(self):
2283+
dim = 'x'
2284+
align_type = "outer"
2285+
coords_l, coords_r = [0, 1, 2], [1, 2, 3]
2286+
missing_3 = xr.DataArray(coords_l, [(dim, coords_l)])
2287+
missing_0 = xr.DataArray(coords_r, [(dim, coords_r)])
2288+
with xr.set_options(arithmetic_join=align_type):
2289+
actual = missing_0 + missing_3
2290+
missing_0_aligned, missing_3_aligned = xr.align(missing_0,
2291+
missing_3,
2292+
join=align_type)
2293+
expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])])
2294+
self.assertDataArrayEqual(actual, expected)

xarray/test/test_dataset.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
import pandas as pd
18+
import xarray as xr
1819
import pytest
1920

2021
from xarray import (align, broadcast, concat, merge, conventions, backends,
@@ -2935,6 +2936,37 @@ def test_filter_by_attrs(self):
29352936
for var in new_ds.data_vars:
29362937
self.assertEqual(new_ds[var].height, '10 m')
29372938

2939+
def test_binary_op_join_setting(self):
2940+
# arithmetic_join applies to data array coordinates
2941+
missing_2 = xr.Dataset({'x':[0, 1]})
2942+
missing_0 = xr.Dataset({'x':[1, 2]})
2943+
with xr.set_options(arithmetic_join='outer'):
2944+
actual = missing_2 + missing_0
2945+
expected = xr.Dataset({'x':[0, 1, 2]})
2946+
self.assertDatasetEqual(actual, expected)
2947+
2948+
# arithmetic join also applies to data_vars
2949+
ds1 = xr.Dataset({'foo': 1, 'bar': 2})
2950+
ds2 = xr.Dataset({'bar': 2, 'baz': 3})
2951+
expected = xr.Dataset({'bar': 4}) # default is inner joining
2952+
actual = ds1 + ds2
2953+
self.assertDatasetEqual(actual, expected)
2954+
2955+
with xr.set_options(arithmetic_join='outer'):
2956+
expected = xr.Dataset({'foo': np.nan, 'bar': 4, 'baz': np.nan})
2957+
actual = ds1 + ds2
2958+
self.assertDatasetEqual(actual, expected)
2959+
2960+
with xr.set_options(arithmetic_join='left'):
2961+
expected = xr.Dataset({'foo': np.nan, 'bar': 4})
2962+
actual = ds1 + ds2
2963+
self.assertDatasetEqual(actual, expected)
2964+
2965+
with xr.set_options(arithmetic_join='right'):
2966+
expected = xr.Dataset({'bar': 4, 'baz': np.nan})
2967+
actual = ds1 + ds2
2968+
self.assertDatasetEqual(actual, expected)
2969+
29382970

29392971
### Py.test tests
29402972

0 commit comments

Comments
 (0)