Skip to content

Commit 865adc9

Browse files
authored
Merge pull request #2983 from oesteban/fix/pickling-outputmultiobjects
FIX: Correctly pickle ``OuputMulti{Object,Path}`` traits
2 parents 7262b24 + 0cd60b6 commit 865adc9

File tree

2 files changed

+30
-88
lines changed

2 files changed

+30
-88
lines changed

nipype/interfaces/base/specs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Undefined,
3131
isdefined,
3232
has_metadata,
33+
OutputMultiObject,
3334
)
3435

3536
from ... import config, __version__
@@ -316,6 +317,32 @@ def _get_sorteddict(self,
316317
def __all__(self):
317318
return self.copyable_trait_names()
318319

320+
def __getstate__(self):
321+
"""
322+
Override __getstate__ so that OutputMultiObjects are correctly pickled.
323+
324+
>>> class OutputSpec(TraitedSpec):
325+
... out = OutputMultiObject(traits.List(traits.Int))
326+
>>> spec = OutputSpec()
327+
>>> spec.out = [[4]]
328+
>>> spec.out
329+
[4]
330+
331+
>>> spec.__getstate__()['out']
332+
[[4]]
333+
334+
>>> spec.__setstate__(spec.__getstate__())
335+
>>> spec.out
336+
[4]
337+
338+
"""
339+
state = super(BaseTraitedSpec, self).__getstate__()
340+
for key in self.__all__:
341+
_trait_spec = self.trait(key)
342+
if _trait_spec.is_trait_type(OutputMultiObject):
343+
state[key] = _trait_spec.handler.get_value(self, key)
344+
return state
345+
319346

320347
def _deepcopypatch(self, memo):
321348
"""

nipype/pipeline/engine/utils.py

Lines changed: 3 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -226,87 +226,12 @@ def write_report(node, report_type=None, is_mapnode=False):
226226
return
227227

228228

229-
def _identify_collapses(hastraits):
230-
""" Identify traits that will collapse when being set to themselves.
231-
232-
``OutputMultiObject``s automatically unwrap a list of length 1 to directly
233-
reference the element of that list.
234-
If that element is itself a list of length 1, then the following will
235-
result in modified values.
236-
237-
hastraits.trait_set(**hastraits.trait_get())
238-
239-
Cloning performs this operation on a copy of the original traited object,
240-
allowing us to identify traits that will be affected.
241-
"""
242-
raw = hastraits.trait_get()
243-
cloned = hastraits.clone_traits().trait_get()
244-
245-
collapsed = set()
246-
for key in cloned:
247-
orig = raw[key]
248-
new = cloned[key]
249-
# Allow numpy to handle the equality checks, as mixed lists and arrays
250-
# can be problematic.
251-
if isinstance(orig, list) and len(orig) == 1 and (
252-
not np.array_equal(orig, new) and np.array_equal(orig[0], new)):
253-
collapsed.add(key)
254-
255-
return collapsed
256-
257-
258-
def _uncollapse(indexable, collapsed):
259-
""" Wrap collapsible values in a list to prevent double-collapsing.
260-
261-
Should be used with _identify_collapses to provide the following
262-
idempotent operation:
263-
264-
collapsed = _identify_collapses(hastraits)
265-
hastraits.trait_set(**_uncollapse(hastraits.trait_get(), collapsed))
266-
267-
NOTE: Modifies object in-place, in addition to returning it.
268-
"""
269-
270-
for key in indexable:
271-
if key in collapsed:
272-
indexable[key] = [indexable[key]]
273-
return indexable
274-
275-
276-
def _protect_collapses(hastraits):
277-
""" A collapse-protected replacement for hastraits.trait_get()
278-
279-
May be used as follows to provide an idempotent trait_set:
280-
281-
hastraits.trait_set(**_protect_collapses(hastraits))
282-
"""
283-
collapsed = _identify_collapses(hastraits)
284-
return _uncollapse(hastraits.trait_get(), collapsed)
285-
286-
287229
def save_resultfile(result, cwd, name):
288230
"""Save a result pklz file to ``cwd``"""
289231
resultsfile = os.path.join(cwd, 'result_%s.pklz' % name)
290-
if result.outputs:
291-
try:
292-
collapsed = _identify_collapses(result.outputs)
293-
outputs = _uncollapse(result.outputs.trait_get(), collapsed)
294-
# Double-protect tosave so that the original, uncollapsed trait
295-
# is saved in the pickle file. Thus, when the loading process
296-
# collapses, the original correct value is loaded.
297-
tosave = _uncollapse(outputs.copy(), collapsed)
298-
except AttributeError:
299-
tosave = outputs = result.outputs.dictcopy() # outputs was a bunch
300-
for k, v in list(modify_paths(tosave, relative=True, basedir=cwd).items()):
301-
setattr(result.outputs, k, v)
302-
303232
savepkl(resultsfile, result)
304233
logger.debug('saved results in %s', resultsfile)
305234

306-
if result.outputs:
307-
for k, v in list(outputs.items()):
308-
setattr(result.outputs, k, v)
309-
310235

311236
def load_resultfile(path, name):
312237
"""
@@ -349,20 +274,10 @@ def load_resultfile(path, name):
349274
logger.debug(
350275
'some file does not exist. hence trait cannot be set')
351276
else:
352-
if result.outputs:
353-
try:
354-
outputs = _protect_collapses(result.outputs)
355-
except AttributeError:
356-
outputs = result.outputs.dictcopy() # outputs == Bunch
357-
try:
358-
for k, v in list(modify_paths(outputs, relative=False,
359-
basedir=path).items()):
360-
setattr(result.outputs, k, v)
361-
except FileNotFoundError:
362-
logger.debug('conversion to full path results in '
363-
'non existent file')
364277
aggregate = False
365-
pkl_file.close()
278+
finally:
279+
pkl_file.close()
280+
366281
logger.debug('Aggregate: %s', aggregate)
367282
return result, aggregate, attribute_error
368283

0 commit comments

Comments
 (0)