Skip to content

Commit 2a6bb8f

Browse files
FredLoneysatra
FredLoney
authored andcommitted
Handle the alternate synchronize iterables format.
Make iterables validation more robust.
1 parent fa08d85 commit 2a6bb8f

File tree

2 files changed

+130
-44
lines changed

2 files changed

+130
-44
lines changed

nipype/pipeline/tests/test_engine.py

+59
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,65 @@ def test_itersource_expansion():
269269
# => 3 * 14 = 42 nodes in the group
270270
yield assert_equal, len(pe.generate_expanded_graph(wf3._flatgraph).nodes()), 42
271271

272+
def test_itersource_synchronize1_expansion():
273+
import nipype.pipeline.engine as pe
274+
wf1 = pe.Workflow(name='test')
275+
node1 = pe.Node(TestInterface(),name='node1')
276+
node1.iterables = [('input1',[1,2]), ('input2',[3,4])]
277+
node1.synchronize = True
278+
node2 = pe.Node(TestInterface(),name='node2')
279+
wf1.connect(node1,'output1', node2, 'input1')
280+
node3 = pe.Node(TestInterface(),name='node3')
281+
node3.itersource = ('node1', ['input1', 'input2'])
282+
node3.iterables = [('input1', {(1,3):[5,6]}), ('input2', {(1,3):[7,8], (2,4): [9]})]
283+
wf1.connect(node2,'output1', node3, 'input1')
284+
node4 = pe.Node(TestInterface(),name='node4')
285+
wf1.connect(node3,'output1', node4, 'input1')
286+
wf3 = pe.Workflow(name='group')
287+
for i in [0,1,2]:
288+
wf3.add_nodes([wf1.clone(name='test%d'%i)])
289+
wf3._flatgraph = wf3._create_flat_graph()
290+
291+
# each expanded graph clone has:
292+
# 2 node1 expansion nodes,
293+
# 1 node2 per node1 replicate,
294+
# 2 node3 replicates for the node1 input1 value 1,
295+
# 3 node3 replicates for the node1 input1 value 2 and
296+
# 1 node4 successor per node3 replicate
297+
# => 2 + 2 + (2 + 3) + 5 = 14 nodes per expanded graph clone
298+
# => 3 * 14 = 42 nodes in the group
299+
yield assert_equal, len(pe.generate_expanded_graph(wf3._flatgraph).nodes()), 42
300+
301+
def test_itersource_synchronize2_expansion():
302+
import nipype.pipeline.engine as pe
303+
wf1 = pe.Workflow(name='test')
304+
node1 = pe.Node(TestInterface(),name='node1')
305+
node1.iterables = [('input1',[1,2]), ('input2',[3,4])]
306+
node1.synchronize = True
307+
node2 = pe.Node(TestInterface(),name='node2')
308+
wf1.connect(node1,'output1', node2, 'input1')
309+
node3 = pe.Node(TestInterface(),name='node3')
310+
node3.itersource = ('node1', ['input1', 'input2'])
311+
node3.synchronize = True
312+
node3.iterables = [('input1', 'input2'), {(1,3):[(5,7), (6,8)], (2,4):[(None,9)]}]
313+
wf1.connect(node2,'output1', node3, 'input1')
314+
node4 = pe.Node(TestInterface(),name='node4')
315+
wf1.connect(node3,'output1', node4, 'input1')
316+
wf3 = pe.Workflow(name='group')
317+
for i in [0,1,2]:
318+
wf3.add_nodes([wf1.clone(name='test%d'%i)])
319+
wf3._flatgraph = wf3._create_flat_graph()
320+
321+
# each expanded graph clone has:
322+
# 2 node1 expansion nodes,
323+
# 1 node2 per node1 replicate,
324+
# 2 node3 replicates for the node1 input1 value 1,
325+
# 1 node3 replicates for the node1 input1 value 2 and
326+
# 1 node4 successor per node3 replicate
327+
# => 2 + 2 + (2 + 1) + 3 = 10 nodes per expanded graph clone
328+
# => 3 * 10 = 30 nodes in the group
329+
yield assert_equal, len(pe.generate_expanded_graph(wf3._flatgraph).nodes()), 30
330+
272331
def test_disconnect():
273332
import nipype.pipeline.engine as pe
274333
from nipype.interfaces.utility import IdentityInterface

nipype/pipeline/utils.py

+71-44
Original file line numberDiff line numberDiff line change
@@ -645,14 +645,15 @@ def generate_expanded_graph(graph_in):
645645
key = src_values[0]
646646
else:
647647
key = tuple(src_values)
648-
# the iterables is a {field: lambda} dictionary, where the
649-
# lambda returns a {source key: iteration list} dictionary
650-
iterables = {}
651-
for field, func in inode.iterables.iteritems():
652-
# the {source key: iteration list} dictionary
653-
lookup = func()
654-
if lookup.has_key(key):
655-
iterables[field] = lambda: lookup[key]
648+
# The itersource iterables is a {field: lookup} dictionary, where the
649+
# lookup is a {source key: iteration list} dictionary. Look up the
650+
# current iterable value using the predecessor itersource input values.
651+
iter_dict = {field: lookup[key] for field, lookup in inode.iterables
652+
if lookup.has_key(key)}
653+
# convert the iterables to the standard {field: function} format
654+
iter_items = map(lambda(field, value): (field, lambda: value),
655+
iter_dict.iteritems())
656+
iterables = dict(iter_items)
656657
else:
657658
iterables = inode.iterables.copy()
658659
inode.iterables = None
@@ -800,56 +801,82 @@ def _standardize_iterables(node):
800801
iterables = node.iterables
801802
# The candidate iterable fields
802803
fields = set(node.inputs.copyable_trait_names())
803-
804-
# Synchronize iterables can be in [fields, value tuples] format
805-
# rather than [(field, value list), (field, value list), ...]
806-
if node.synchronize and len(iterables) == 2:
807-
first, last = iterables
808-
if all((isinstance(item, str) and item in fields
809-
for item in first)):
810-
iterables = _transpose_iterables(first, last)
811-
804+
# Flag indicating whether the iterables are in the alternate
805+
# synchronize form and are not converted to a standard format.
806+
synchronize = False
807+
# A synchronize iterables node without an itersource can be in
808+
# [fields, value tuples] format rather than
809+
# [(field, value list), (field, value list), ...]
810+
if node.synchronize:
811+
if len(iterables) == 2:
812+
first, last = iterables
813+
if all((isinstance(item, str) and item in fields
814+
for item in first)):
815+
iterables = _transpose_iterables(first, last)
816+
812817
# Convert a tuple to a list
813818
if isinstance(iterables, tuple):
814819
iterables = [iterables]
820+
# Validate the standard [(field, values)] format
821+
_validate_iterables(node, iterables, fields)
815822
# Convert a list to a dictionary
816823
if isinstance(iterables, list):
817-
# Validate the format
818-
for item in iterables:
819-
try:
820-
if len(item) != 2:
821-
raise ValueError("The %s iterables do not consist of"
822-
" (field, values) pairs" % node.name)
823-
except TypeError, e:
824-
raise TypeError("The %s iterables is not iterable: %s"
825-
% (node.name, e))
826-
# Convert the values to functions. This is a legacy Nipype
827-
# requirement with unknown rationale.
828-
iter_items = map(lambda(field, value): (field, lambda: value),
829-
iterables)
830-
# Make the iterables dictionary
831-
iterables = dict(iter_items)
832-
elif not isinstance(iterables, dict):
824+
# Convert a values list to a function. This is a legacy
825+
# Nipype requirement with unknown rationale.
826+
if not node.itersource:
827+
iter_items = map(lambda(field, value): (field, lambda: value),
828+
iterables)
829+
iterables = dict(iter_items)
830+
node.iterables = iterables
831+
832+
def _validate_iterables(node, iterables, fields):
833+
"""
834+
Raise TypeError if an iterables member is not iterable.
835+
836+
Raise ValueError if an iterables member is not a (field, values) pair.
837+
838+
Raise ValueError if an iterable field is not in the inputs.
839+
"""
840+
# The iterables can be a {field: value list} dictionary.
841+
if isinstance(iterables, dict):
842+
iterables = iterables.items()
843+
elif not isinstance(iterables, tuple) and not isinstance(iterables, list):
833844
raise ValueError("The %s iterables type is not a list or a dictionary:"
834845
" %s" % (node.name, iterables.__class__))
835-
836-
# Validate the iterable fields
837-
for field in iterables.iterkeys():
846+
for item in iterables:
847+
try:
848+
if len(item) != 2:
849+
raise ValueError("The %s iterables is not a [(field, values)]"
850+
" list" % node.name)
851+
except TypeError, e:
852+
raise TypeError("A %s iterables member is not iterable: %s"
853+
% (node.name, e))
854+
field, _ = item
838855
if field not in fields:
839856
raise ValueError("The %s iterables field is unrecognized: %s"
840857
% (node.name, field))
841-
842-
# Assign to the standard form
843-
node.iterables = iterables
844858

845859
def _transpose_iterables(fields, values):
846860
"""
847-
Converts the given fields and tuple values into a list of
848-
iterable (field: value list) pairs, suitable for setting
849-
a node iterables property.
861+
Converts the given fields and tuple values into a standardized
862+
iterables value.
863+
864+
If the input values is a synchronize iterables dictionary, then
865+
the result is a (field, {key: values}) list.
866+
867+
Otherwise, the result is a list of (field: value list) pairs.
850868
"""
851-
return zip(fields, [filter(lambda(v): v != None, list(transpose))
852-
for transpose in zip(*values)])
869+
if isinstance(values, dict):
870+
transposed = {field: defaultdict(list) for field in fields}
871+
for key, tuples in values.iteritems():
872+
for kvals in tuples:
873+
for idx, val in enumerate(kvals):
874+
if val != None:
875+
transposed[fields[idx]][key].append(val)
876+
return transposed.items()
877+
else:
878+
return zip(fields, [filter(lambda(v): v != None, list(transpose))
879+
for transpose in zip(*values)])
853880

854881
def export_graph(graph_in, base_dir=None, show=False, use_execgraph=False,
855882
show_connectinfo=False, dotfilename='graph.dot', format='png',

0 commit comments

Comments
 (0)