Commit 2f590f7

Author: Joe Hamman
fix distributed writes (#1793)
* distributed tests that write dask arrays
* Change zarr test to synchronous API
* initial go at __setitem__ on array wrappers
* fixes for scipy
* cleanup after merging with upstream/master
* needless duplication of tests to work around pytest bug
* use netcdf_variable instead of get_array()
* use synchronous dask.distributed test harness
* cleanup tests
* per scheduler locks and autoclose behavior for writes
* HDF5_LOCK and CombinedLock
* integration test for distributed locks
* more tests and set isopen to false when pickling
* Fixing style errors.
* ds property on DataStorePickleMixin
* stickler-ci
* compat fixes for other backends
* HDF5_USE_FILE_LOCKING = False in test_distributed
* style fix
* update tests to only expect netcdf4 to work, docstrings, and some cleanup in to_netcdf
* Fixing style errors.
* fix imports after merge
* fix more import bugs
* update docs
* fix for pynio
* cleanup locks and use pytest monkeypatch for environment variable
* fix failing test using combined lock
1 parent 8c6a284 commit 2f590f7

12 files changed: +331 additions, -64 deletions


doc/dask.rst

Lines changed: 8 additions & 0 deletions
@@ -100,6 +100,14 @@ Once you've manipulated a dask array, you can still write a dataset too big to
 fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in the
 usual way.
 
+.. note::
+
+    When using dask's distributed scheduler to write NETCDF4 files,
+    it may be necessary to set the environment variable `HDF5_USE_FILE_LOCKING=FALSE`
+    to avoid competing locks within the HDF5 SWMR file locking scheme. Note that
+    writing netCDF files with dask's distributed scheduler is only supported for
+    the `netcdf4` backend.
+
 A dataset can also be converted to a dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`.
 
 .. ipython:: python
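
A minimal sketch of such a distributed write, following the note added above (the
file names and chunk sizes here are illustrative, not part of the commit):

    import os
    # assumption: must be set before HDF5/netCDF4 is first imported
    os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

    import xarray as xr
    from dask.distributed import Client

    client = Client()  # local distributed scheduler

    ds = xr.open_dataset('my_data.nc').chunk({'time': 100})
    ds.to_netcdf('out.nc', engine='netcdf4', format='NETCDF4')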

doc/io.rst

Lines changed: 3 additions & 3 deletions
@@ -672,9 +672,9 @@ files into a single Dataset by making use of :py:func:`~xarray.concat`.
 
 .. note::
 
-    Version 0.5 includes support for manipulating datasets that
-    don't fit into memory with dask_. If you have dask installed, you can open
-    multiple files simultaneously using :py:func:`~xarray.open_mfdataset`::
+    Xarray includes support for manipulating datasets that don't fit into memory
+    with dask_. If you have dask installed, you can open multiple files
+    simultaneously using :py:func:`~xarray.open_mfdataset`::
 
         xr.open_mfdataset('my/files/*.nc')

doc/whats-new.rst

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,13 @@ Documentation
 Enhancements
 ~~~~~~~~~~~~
 
+- Support for writing xarray datasets to netCDF files (netcdf4 backend only)
+  when using the `dask.distributed <https://distributed.readthedocs.io>`_
+  scheduler (:issue:`1464`).
+  By `Joe Hamman <https://github.com/jhamman>`_.
+
+
+- Fixed to_netcdf when using dask distributed
 - Support lazy vectorized-indexing. After this change, flexible indexing such
   as orthogonal/vectorized indexing, becomes possible for all the backend
   arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)

xarray/backends/api.py

Lines changed: 31 additions & 4 deletions
@@ -12,7 +12,8 @@
 from ..core.combine import auto_combine
 from ..core.pycompat import basestring, path_type
 from ..core.utils import close_on_error, is_remote_uri
-from .common import GLOBAL_LOCK, ArrayWriter
+from .common import (
+    HDF5_LOCK, ArrayWriter, CombinedLock, get_scheduler, get_scheduler_lock)
 
 DATAARRAY_NAME = '__xarray_dataarray_name__'
 DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'
@@ -64,9 +65,9 @@ def _default_lock(filename, engine):
         else:
             # TODO: identify netcdf3 files and don't use the global lock
             # for them
-            lock = GLOBAL_LOCK
+            lock = HDF5_LOCK
     elif engine in {'h5netcdf', 'pynio'}:
-        lock = GLOBAL_LOCK
+        lock = HDF5_LOCK
     else:
         lock = False
     return lock
@@ -129,6 +130,20 @@ def _protect_dataset_variables_inplace(dataset, cache):
         variable.data = data
 
 
+def _get_lock(engine, scheduler, format, path_or_file):
+    """ Get the lock(s) that apply to a particular scheduler/engine/format"""
+
+    locks = []
+    if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']:
+        locks.append(HDF5_LOCK)
+    locks.append(get_scheduler_lock(scheduler, path_or_file))
+
+    # When we have more than one lock, use the CombinedLock wrapper class
+    lock = CombinedLock(locks) if len(locks) > 1 else locks[0]
+
+    return lock
+
+
 def open_dataset(filename_or_obj, group=None, decode_cf=True,
                  mask_and_scale=True, decode_times=True, autoclose=False,
                  concat_characters=True, decode_coords=True, engine=None,
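
A rough illustration of the lock selection performed by the new `_get_lock` helper
(return values are sketched from the code above, not captured output):

    # NETCDF4 + netcdf4 engine under the distributed scheduler: both the
    # global HDF5 lock and a per-file distributed lock apply
    _get_lock('netcdf4', 'distributed', 'NETCDF4', 'out.nc')
    # -> CombinedLock([HDF5_LOCK, distributed.Lock('out.nc')])

    # a netCDF3 write with scipy needs no HDF5 lock, so only the
    # scheduler's own lock is returned
    _get_lock('scipy', 'threaded', 'NETCDF3_64BIT', 'out.nc')
    # -> SerializableLock()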
@@ -620,8 +635,20 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
     # if a writer is provided, store asynchronously
     sync = writer is None
 
+    # handle scheduler specific logic
+    scheduler = get_scheduler()
+    if (dataset.chunks and scheduler in ['distributed', 'multiprocessing'] and
+            engine != 'netcdf4'):
+        raise NotImplementedError("Writing netCDF files with the %s backend "
+                                  "is not currently supported with dask's %s "
+                                  "scheduler" % (engine, scheduler))
+    lock = _get_lock(engine, scheduler, format, path_or_file)
+    autoclose = (dataset.chunks and
+                 scheduler in ['distributed', 'multiprocessing'])
+
     target = path_or_file if path_or_file is not None else BytesIO()
-    store = store_open(target, mode, format, group, writer)
+    store = store_open(target, mode, format, group, writer,
+                       autoclose=autoclose, lock=lock)
 
     if unlimited_dims is None:
         unlimited_dims = dataset.encoding.get('unlimited_dims', None)
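
The effect of the new guard in to_netcdf, sketched (the dataset and paths are
illustrative):

    import xarray as xr
    from dask.distributed import Client

    client = Client()
    ds = xr.open_dataset('my_data.nc').chunk({'time': 100})

    ds.to_netcdf('ok.nc', engine='netcdf4')  # supported
    ds.to_netcdf('bad.nc', engine='scipy')
    # NotImplementedError: Writing netCDF files with the scipy backend
    # is not currently supported with dask's distributed scheduler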

xarray/backends/common.py

Lines changed: 116 additions & 19 deletions
@@ -2,6 +2,8 @@
 
 import contextlib
 import logging
+import multiprocessing
+import threading
 import time
 import traceback
 import warnings
@@ -14,11 +16,12 @@
 from ..core.pycompat import dask_array_type, iteritems
 from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin
 
+# Import default lock
 try:
-    from dask.utils import SerializableLock as Lock
+    from dask.utils import SerializableLock
+    HDF5_LOCK = SerializableLock()
 except ImportError:
-    from threading import Lock
-
+    HDF5_LOCK = threading.Lock()
 
 # Create a logger object, but don't add any handlers. Leave that to user code.
 logger = logging.getLogger(__name__)
@@ -27,8 +30,54 @@
 NONE_VAR_NAME = '__values__'
 
 
-# dask.utils.SerializableLock if available, otherwise just a threading.Lock
-GLOBAL_LOCK = Lock()
+def get_scheduler(get=None, collection=None):
+    """ Determine the dask scheduler that is being used.
+
+    None is returned if no dask scheduler is active.
+
+    See also
+    --------
+    dask.utils.effective_get
+    """
+    try:
+        from dask.utils import effective_get
+        actual_get = effective_get(get, collection)
+        try:
+            from dask.distributed import Client
+            if isinstance(actual_get.__self__, Client):
+                return 'distributed'
+        except (ImportError, AttributeError):
+            try:
+                import dask.multiprocessing
+                if actual_get == dask.multiprocessing.get:
+                    return 'multiprocessing'
+                else:
+                    return 'threaded'
+            except ImportError:
+                return 'threaded'
+    except ImportError:
+        return None
+
+
+def get_scheduler_lock(scheduler, path_or_file=None):
+    """ Get the appropriate lock for a certain situation based on the dask
+    scheduler used.
+
+    See Also
+    --------
+    dask.utils.get_scheduler_lock
+    """
+
+    if scheduler == 'distributed':
+        from dask.distributed import Lock
+        return Lock(path_or_file)
+    elif scheduler == 'multiprocessing':
+        return multiprocessing.Lock()
+    elif scheduler == 'threaded':
+        from dask.utils import SerializableLock
+        return SerializableLock()
+    else:
+        return threading.Lock()
 
 
 def _encode_variable_name(name):
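
A short sketch of how these two helpers behave (the 'threaded' result assumes
dask's default scheduler is in effect):

    import dask.array as da
    from xarray.backends.common import get_scheduler, get_scheduler_lock

    arr = da.zeros((10,), chunks=5)
    get_scheduler(collection=arr)   # -> 'threaded' by default
    get_scheduler_lock('threaded')  # -> SerializableLock()

    # with a dask.distributed Client active:
    # get_scheduler()                            -> 'distributed'
    # get_scheduler_lock('distributed', 'a.nc')  -> distributed.Lock('a.nc')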
@@ -77,6 +126,39 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
         time.sleep(1e-3 * next_delay)
 
 
+class CombinedLock(object):
+    """A combination of multiple locks.
+
+    Like a locked door, a CombinedLock is locked if any of its constituent
+    locks are locked.
+    """
+
+    def __init__(self, locks):
+        self.locks = tuple(set(locks))  # remove duplicates
+
+    def acquire(self, *args):
+        return all(lock.acquire(*args) for lock in self.locks)
+
+    def release(self, *args):
+        for lock in self.locks:
+            lock.release(*args)
+
+    def __enter__(self):
+        for lock in self.locks:
+            lock.__enter__()
+
+    def __exit__(self, *args):
+        for lock in self.locks:
+            lock.__exit__(*args)
+
+    @property
+    def locked(self):
+        return any(lock.locked for lock in self.locks)
+
+    def __repr__(self):
+        return "CombinedLock(%r)" % list(self.locks)
+
+
 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):
 
     def __array__(self, dtype=None):
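
A small usage sketch of CombinedLock (the per-file lock here is an arbitrary
threading.Lock chosen for illustration):

    import threading
    from xarray.backends.common import HDF5_LOCK, CombinedLock

    lock = CombinedLock([HDF5_LOCK, threading.Lock()])
    with lock:
        pass  # every constituent lock is held inside this block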
@@ -85,7 +167,9 @@ def __array__(self, dtype=None):
 
 
 class AbstractDataStore(Mapping):
-    _autoclose = False
+    _autoclose = None
+    _ds = None
+    _isopen = False
 
     def __iter__(self):
         return iter(self.variables)
@@ -168,7 +252,7 @@ def __exit__(self, exception_type, exception_value, traceback):
 
 
 class ArrayWriter(object):
-    def __init__(self, lock=GLOBAL_LOCK):
+    def __init__(self, lock=HDF5_LOCK):
         self.sources = []
         self.targets = []
         self.lock = lock
@@ -178,11 +262,7 @@ def add(self, source, target):
             self.sources.append(source)
             self.targets.append(target)
         else:
-            try:
-                target[...] = source
-            except TypeError:
-                # workaround for GH: scipy/scipy#6880
-                target[:] = source
+            target[...] = source
 
     def sync(self):
         if self.sources:
@@ -193,9 +273,9 @@ def sync(self):
 
 
 class AbstractWritableDataStore(AbstractDataStore):
-    def __init__(self, writer=None):
+    def __init__(self, writer=None, lock=HDF5_LOCK):
         if writer is None:
-            writer = ArrayWriter()
+            writer = ArrayWriter(lock=lock)
         self.writer = writer
 
     def encode(self, variables, attributes):
@@ -239,6 +319,9 @@ def set_variable(self, k, v):  # pragma: no cover
         raise NotImplementedError
 
     def sync(self):
+        if self._isopen and self._autoclose:
+            # datastore will be reopened during write
+            self.close()
         self.writer.sync()
 
     def store_dataset(self, dataset):
@@ -373,27 +456,41 @@ class DataStorePickleMixin(object):
 
     def __getstate__(self):
         state = self.__dict__.copy()
-        del state['ds']
+        del state['_ds']
+        del state['_isopen']
         if self._mode == 'w':
             # file has already been created, don't override when restoring
             state['_mode'] = 'a'
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self.ds = self._opener(mode=self._mode)
+        self._ds = None
+        self._isopen = False
+
+    @property
+    def ds(self):
+        if self._ds is not None and self._isopen:
+            return self._ds
+        ds = self._opener(mode=self._mode)
+        self._isopen = True
+        return ds
 
     @contextlib.contextmanager
-    def ensure_open(self, autoclose):
+    def ensure_open(self, autoclose=None):
         """
         Helper function to make sure datasets are closed and opened
        at appropriate times to avoid too many open file errors.
 
        Use requires `autoclose=True` argument to `open_mfdataset`.
        """
-        if self._autoclose and not self._isopen:
+
+        if autoclose is None:
+            autoclose = self._autoclose
+
+        if not self._isopen:
             try:
-                self.ds = self._opener()
+                self._ds = self._opener()
                 self._isopen = True
                 yield
             finally:
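
A sketch of what the pickling changes buy us: any DataStorePickleMixin subclass
can now round-trip through pickle without carrying an open file handle (the
`store` variable stands in for any such datastore):

    import pickle

    store2 = pickle.loads(pickle.dumps(store))
    store2._isopen  # False: no open handle travels with the pickle
    store2.ds       # lazily reopens the file via self._opener on first access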

xarray/backends/h5netcdf_.py

Lines changed: 9 additions & 5 deletions
@@ -8,7 +8,8 @@
 from ..core import indexing
 from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type
 from ..core.utils import FrozenOrderedDict, close_on_error
-from .common import DataStorePickleMixin, WritableCFDataStore, find_root
+from .common import (
+    HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root)
 from .netCDF4_ import (
     BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding,
     _get_datatype, _nc4_group)
@@ -68,12 +69,12 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
     """
 
     def __init__(self, filename, mode='r', format=None, group=None,
-                 writer=None, autoclose=False):
+                 writer=None, autoclose=False, lock=HDF5_LOCK):
         if format not in [None, 'NETCDF4']:
             raise ValueError('invalid format for h5netcdf backend')
         opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
                                    group=group)
-        self.ds = opener()
+        self._ds = opener()
         if autoclose:
             raise NotImplementedError('autoclose=True is not implemented '
                                       'for the h5netcdf backend pending '
@@ -85,7 +86,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
         self._opener = opener
         self._filename = filename
         self._mode = mode
-        super(H5NetCDFStore, self).__init__(writer)
+        super(H5NetCDFStore, self).__init__(writer, lock=lock)
 
     def open_store_variable(self, name, var):
         with self.ensure_open(autoclose=False):
@@ -177,7 +178,10 @@ def prepare_variable(self, name, variable, check_encoding=False,
 
         for k, v in iteritems(attrs):
             nc4_var.setncattr(k, v)
-        return nc4_var, variable.data
+
+        target = H5NetCDFArrayWrapper(name, self)
+
+        return target, variable.data
 
     def sync(self):
         with self.ensure_open(autoclose=True):
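
The H5NetCDFArrayWrapper target returned by prepare_variable resolves the real
on-disk variable through the (pickleable) datastore at write time; its
__setitem__ is not shown in this diff, but the pattern looks roughly like this
sketch (the body is an assumption, not the committed code):

    class ExampleArrayWrapper(object):
        def __init__(self, variable_name, datastore):
            self.variable_name = variable_name
            self.datastore = datastore  # pickleable via DataStorePickleMixin

        def __setitem__(self, key, value):
            with self.datastore.ensure_open(autoclose=True):
                # look up the on-disk variable lazily, so the wrapper
                # itself never holds an open file handle
                ds = self.datastore.ds
                ds.variables[self.variable_name][key] = value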
