@@ -2,6 +2,8 @@

 import contextlib
 import logging
+import multiprocessing
+import threading
 import time
 import traceback
 import warnings
@@ -14,11 +16,12 @@
 from ..core.pycompat import dask_array_type, iteritems
 from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin

+# Import default lock
 try:
-    from dask.utils import SerializableLock as Lock
+    from dask.utils import SerializableLock
+    HDF5_LOCK = SerializableLock()
 except ImportError:
-    from threading import Lock
-
+    HDF5_LOCK = threading.Lock()

 # Create a logger object, but don't add any handlers. Leave that to user code.
 logger = logging.getLogger(__name__)
@@ -27,8 +30,54 @@
 NONE_VAR_NAME = '__values__'


-# dask.utils.SerializableLock if available, otherwise just a threading.Lock
-GLOBAL_LOCK = Lock()
+def get_scheduler(get=None, collection=None):
+    """Determine the dask scheduler that is being used.
+
+    None is returned if no dask scheduler is active.
+
+    See also
+    --------
+    dask.utils.effective_get
+    """
+    try:
+        from dask.utils import effective_get
+        actual_get = effective_get(get, collection)
+        try:
+            from dask.distributed import Client
+            if isinstance(actual_get.__self__, Client):
+                return 'distributed'
+        except (ImportError, AttributeError):
+            try:
+                import dask.multiprocessing
+                if actual_get == dask.multiprocessing.get:
+                    return 'multiprocessing'
+                else:
+                    return 'threaded'
+            except ImportError:
+                return 'threaded'
+    except ImportError:
+        return None
+
+
+def get_scheduler_lock(scheduler, path_or_file=None):
+    """Get the appropriate lock for a certain situation based on the dask
+    scheduler used.
+
+    See Also
+    --------
+    dask.utils.get_scheduler_lock
+    """
+
+    if scheduler == 'distributed':
+        from dask.distributed import Lock
+        return Lock(path_or_file)
+    elif scheduler == 'multiprocessing':
+        return multiprocessing.Lock()
+    elif scheduler == 'threaded':
+        from dask.utils import SerializableLock
+        return SerializableLock()
+    else:
+        return threading.Lock()


 def _encode_variable_name(name):
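
For orientation, a rough sketch of how the two helpers above are meant to compose (a hedged example, not part of the diff: it assumes this module is xarray.backends.common, that an older dask providing dask.utils.effective_get is installed, and the array and file name are purely illustrative):

    import dask.array as da
    from xarray.backends.common import get_scheduler, get_scheduler_lock

    arr = da.zeros((100, 100), chunks=(10, 10))    # toy dask collection
    scheduler = get_scheduler(collection=arr)      # e.g. 'threaded' by default
    lock = get_scheduler_lock(scheduler, path_or_file='out.nc')  # illustrative path
    with lock:
        pass  # perform file access that must not interleave with other writers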
@@ -77,6 +126,39 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
             time.sleep(1e-3 * next_delay)


+class CombinedLock(object):
+    """A combination of multiple locks.
+
+    Like a locked door, a CombinedLock is locked if any of its constituent
+    locks are locked.
+    """
+
+    def __init__(self, locks):
+        self.locks = tuple(set(locks))  # remove duplicates
+
+    def acquire(self, *args):
+        return all(lock.acquire(*args) for lock in self.locks)
+
+    def release(self, *args):
+        for lock in self.locks:
+            lock.release(*args)
+
+    def __enter__(self):
+        for lock in self.locks:
+            lock.__enter__()
+
+    def __exit__(self, *args):
+        for lock in self.locks:
+            lock.__exit__(*args)
+
+    @property
+    def locked(self):
+        return any(lock.locked for lock in self.locks)
+
+    def __repr__(self):
+        return "CombinedLock(%r)" % list(self.locks)
+
+
 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):

     def __array__(self, dtype=None):
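
A minimal sketch of how CombinedLock composes individual locks (hedged: the per-file lock below is a made-up example, and the import path assumes this module is xarray.backends.common):

    import threading
    from xarray.backends.common import CombinedLock, HDF5_LOCK

    per_file_lock = threading.Lock()                 # hypothetical extra lock
    lock = CombinedLock([HDF5_LOCK, per_file_lock])  # duplicates are removed

    with lock:
        # entering the CombinedLock enters every constituent lock, so both
        # HDF5_LOCK and per_file_lock are held inside this block
        pass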
@@ -85,7 +167,9 @@ def __array__(self, dtype=None):


 class AbstractDataStore(Mapping):
-    _autoclose = False
+    _autoclose = None
+    _ds = None
+    _isopen = False

     def __iter__(self):
         return iter(self.variables)
@@ -168,7 +252,7 @@ def __exit__(self, exception_type, exception_value, traceback):


 class ArrayWriter(object):
-    def __init__(self, lock=GLOBAL_LOCK):
+    def __init__(self, lock=HDF5_LOCK):
         self.sources = []
         self.targets = []
         self.lock = lock
@@ -178,11 +262,7 @@ def add(self, source, target):
             self.sources.append(source)
             self.targets.append(target)
         else:
-            try:
-                target[...] = source
-            except TypeError:
-                # workaround for GH: scipy/scipy#6880
-                target[:] = source
+            target[...] = source

     def sync(self):
         if self.sources:
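
To illustrate the simplified add() above: non-dask data is written to the target immediately, while dask arrays are queued until sync(). A sketch under the assumption that the module is xarray.backends.common, with a plain NumPy array standing in for a real backend target:

    import numpy as np
    from xarray.backends.common import ArrayWriter

    writer = ArrayWriter()               # now defaults to HDF5_LOCK
    target = np.zeros(3)
    writer.add(np.array([1.0, 2.0, 3.0]), target)  # NumPy source: written at once
    print(target)                        # [1. 2. 3.]
    writer.sync()                        # nothing deferred, so nothing to do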
@@ -193,9 +273,9 @@ def sync(self):


 class AbstractWritableDataStore(AbstractDataStore):
-    def __init__(self, writer=None):
+    def __init__(self, writer=None, lock=HDF5_LOCK):
         if writer is None:
-            writer = ArrayWriter()
+            writer = ArrayWriter(lock=lock)
         self.writer = writer

     def encode(self, variables, attributes):
@@ -239,6 +319,9 @@ def set_variable(self, k, v): # pragma: no cover
         raise NotImplementedError

     def sync(self):
+        if self._isopen and self._autoclose:
+            # datastore will be reopened during write
+            self.close()
         self.writer.sync()

     def store_dataset(self, dataset):
@@ -373,27 +456,41 @@ class DataStorePickleMixin(object):

     def __getstate__(self):
         state = self.__dict__.copy()
-        del state['ds']
+        del state['_ds']
+        del state['_isopen']
         if self._mode == 'w':
             # file has already been created, don't override when restoring
             state['_mode'] = 'a'
         return state

     def __setstate__(self, state):
         self.__dict__.update(state)
-        self.ds = self._opener(mode=self._mode)
+        self._ds = None
+        self._isopen = False
+
+    @property
+    def ds(self):
+        if self._ds is not None and self._isopen:
+            return self._ds
+        ds = self._opener(mode=self._mode)
+        self._isopen = True
+        return ds

     @contextlib.contextmanager
-    def ensure_open(self, autoclose):
+    def ensure_open(self, autoclose=None):
         """
         Helper function to make sure datasets are closed and opened
         at appropriate times to avoid too many open file errors.

         Use requires `autoclose=True` argument to `open_mfdataset`.
         """
-        if self._autoclose and not self._isopen:
+
+        if autoclose is None:
+            autoclose = self._autoclose
+
+        if not self._isopen:
             try:
-                self.ds = self._opener()
+                self._ds = self._opener()
                 self._isopen = True
                 yield
             finally:
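
For context, a store built on this mixin would use the reopening machinery roughly as follows (a hedged sketch: ExampleStore, its opener and read_attrs are hypothetical and not part of this diff):

    from xarray.backends.common import DataStorePickleMixin

    class ExampleStore(DataStorePickleMixin):
        def __init__(self, opener, mode='r', autoclose=False):
            self._opener = opener        # callable returning an open file handle
            self._mode = mode
            self._autoclose = autoclose
            self._ds = None
            self._isopen = False

        def read_attrs(self):
            # ensure_open() opens the file if needed; inside the block the
            # ds property returns the open handle without reopening it
            with self.ensure_open(autoclose=None):
                return dict(self.ds.attrs)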