Skip to content

Commit 7262b24

Browse files
authored
Merge pull request #2981 from oesteban/tst/parametrize-join-expansion
TST: Parametrize JoinNode expansion tests over config ``needed_outputs``
2 parents 5965d45 + 2e9ecc1 commit 7262b24

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

nipype/pipeline/engine/tests/test_join.py

Lines changed: 34 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from __future__ import (print_function, division, unicode_literals,
77
absolute_import)
88
from builtins import open
9+
import pytest
910

11+
from .... import config
1012
from ... import engine as pe
1113
from ....interfaces import base as nib
1214
from ....interfaces.utility import IdentityInterface, Function, Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
4547
output1 = nib.traits.Int(desc='ouput')
4648

4749

48-
class IncrementInterface(nib.BaseInterface):
50+
class IncrementInterface(nib.SimpleInterface):
4951
input_spec = IncrementInputSpec
5052
output_spec = IncrementOutputSpec
5153

5254
def _run_interface(self, runtime):
5355
runtime.returncode = 0
56+
self._results['output1'] = self.inputs.input1 + self.inputs.inc
5457
return runtime
5558

56-
def _list_outputs(self):
57-
outputs = self._outputs().get()
58-
outputs['output1'] = self.inputs.input1 + self.inputs.inc
59-
return outputs
60-
6159

6260
_sums = []
6361

@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
7371
operands = nib.traits.List(nib.traits.Int, desc='operands')
7472

7573

76-
class SumInterface(nib.BaseInterface):
74+
class SumInterface(nib.SimpleInterface):
7775
input_spec = SumInputSpec
7876
output_spec = SumOutputSpec
7977

8078
def _run_interface(self, runtime):
81-
runtime.returncode = 0
82-
return runtime
83-
84-
def _list_outputs(self):
8579
global _sum
8680
global _sum_operands
87-
outputs = self._outputs().get()
88-
outputs['operands'] = self.inputs.input1
89-
_sum_operands.append(outputs['operands'])
90-
outputs['output1'] = sum(self.inputs.input1)
91-
_sums.append(outputs['output1'])
92-
return outputs
81+
runtime.returncode = 0
82+
self._results['operands'] = self.inputs.input1
83+
self._results['output1'] = sum(self.inputs.input1)
84+
_sum_operands.append(self.inputs.input1)
85+
_sums.append(sum(self.inputs.input1))
86+
return runtime
9387

9488

9589
_set_len = None
@@ -148,35 +142,48 @@ def _list_outputs(self):
148142
return outputs
149143

150144

151-
def test_join_expansion(tmpdir):
145+
@pytest.mark.parametrize('needed_outputs', ['true', 'false'])
146+
def test_join_expansion(tmpdir, needed_outputs):
147+
global _sums
148+
global _sum_operands
149+
global _products
152150
tmpdir.chdir()
153151

152+
# Clean up, just in case some other test modified them
153+
_products = []
154+
_sum_operands = []
155+
_sums = []
156+
157+
prev_state = config.get('execution', 'remove_unnecessary_outputs')
158+
config.set('execution', 'remove_unnecessary_outputs', needed_outputs)
154159
# Make the workflow.
155160
wf = pe.Workflow(name='test')
156161
# the iterated input node
157162
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
158163
inputspec.iterables = [('n', [1, 2])]
159164
# a pre-join node in the iterated path
160165
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
161-
wf.connect(inputspec, 'n', pre_join1, 'input1')
162166
# another pre-join node in the iterated path
163167
pre_join2 = pe.Node(IncrementInterface(), name='pre_join2')
164-
wf.connect(pre_join1, 'output1', pre_join2, 'input1')
165168
# the join node
166169
join = pe.JoinNode(
167170
SumInterface(),
168171
joinsource='inputspec',
169172
joinfield='input1',
170173
name='join')
171-
wf.connect(pre_join2, 'output1', join, 'input1')
172174
# an uniterated post-join node
173175
post_join1 = pe.Node(IncrementInterface(), name='post_join1')
174-
wf.connect(join, 'output1', post_join1, 'input1')
175176
# a post-join node in the iterated path
176177
post_join2 = pe.Node(ProductInterface(), name='post_join2')
177-
wf.connect(join, 'output1', post_join2, 'input1')
178-
wf.connect(pre_join1, 'output1', post_join2, 'input2')
179178

179+
wf.connect([
180+
(inputspec, pre_join1, [('n', 'input1')]),
181+
(pre_join1, pre_join2, [('output1', 'input1')]),
182+
(pre_join1, post_join2, [('output1', 'input2')]),
183+
(pre_join2, join, [('output1', 'input1')]),
184+
(join, post_join1, [('output1', 'input1')]),
185+
(join, post_join2, [('output1', 'input1')]),
186+
])
180187
result = wf.run()
181188

182189
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +192,8 @@ def test_join_expansion(tmpdir):
185192
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186193
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187194
# Nipype factors away the IdentityInterface.
188-
assert len(
189-
result.nodes()) == 8, "The number of expanded nodes is incorrect."
195+
assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect."
196+
190197
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191198
assert len(_sums) == 1, "The number of join outputs is incorrect"
192199
assert _sums[
@@ -197,6 +204,7 @@ def test_join_expansion(tmpdir):
197204
# there are two iterations of the post-join node in the iterable path
198205
assert len(_products) == 2,\
199206
"The number of iterated post-join outputs is incorrect"
207+
config.set('execution', 'remove_unnecessary_outputs', prev_state)
200208

201209

202210
def test_node_joinsource(tmpdir):

0 commit comments

Comments
 (0)