Skip to content

Commit 0da0f2b

Browse files
committed
TST: Parametrize over config needed_outputs JoinNode expansion tests
1 parent 6f2f94f commit 0da0f2b

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

nipype/pipeline/engine/tests/test_join.py

+33-26
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from __future__ import (print_function, division, unicode_literals,
77
absolute_import)
88
from builtins import open
9+
import pytest
910

11+
from .... import config
1012
from ... import engine as pe
1113
from ....interfaces import base as nib
1214
from ....interfaces.utility import IdentityInterface, Function, Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
4547
output1 = nib.traits.Int(desc='ouput')
4648

4749

48-
class IncrementInterface(nib.BaseInterface):
50+
class IncrementInterface(nib.SimpleInterface):
4951
input_spec = IncrementInputSpec
5052
output_spec = IncrementOutputSpec
5153

5254
def _run_interface(self, runtime):
5355
runtime.returncode = 0
56+
self._results['output1'] = self.inputs.input1 + self.inputs.inc
5457
return runtime
5558

56-
def _list_outputs(self):
57-
outputs = self._outputs().get()
58-
outputs['output1'] = self.inputs.input1 + self.inputs.inc
59-
return outputs
60-
6159

6260
_sums = []
6361

@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
7371
operands = nib.traits.List(nib.traits.Int, desc='operands')
7472

7573

76-
class SumInterface(nib.BaseInterface):
74+
class SumInterface(nib.SimpleInterface):
7775
input_spec = SumInputSpec
7876
output_spec = SumOutputSpec
7977

8078
def _run_interface(self, runtime):
81-
runtime.returncode = 0
82-
return runtime
83-
84-
def _list_outputs(self):
8579
global _sum
8680
global _sum_operands
87-
outputs = self._outputs().get()
88-
outputs['operands'] = self.inputs.input1
89-
_sum_operands.append(outputs['operands'])
90-
outputs['output1'] = sum(self.inputs.input1)
91-
_sums.append(outputs['output1'])
92-
return outputs
81+
runtime.returncode = 0
82+
self._results['operands'] = self.inputs.input1
83+
self._results['output1'] = sum(self.inputs.input1)
84+
_sum_operands.append(self.inputs.input1)
85+
_sums.append(sum(self.inputs.input1))
86+
return runtime
9387

9488

9589
_set_len = None
@@ -148,35 +142,47 @@ def _list_outputs(self):
148142
return outputs
149143

150144

151-
def test_join_expansion(tmpdir):
145+
@pytest.mark.parametrize('needed_outputs', [True, False])
146+
def test_join_expansion(tmpdir, needed_outputs):
147+
global _sums
148+
global _sum_operands
149+
global _products
152150
tmpdir.chdir()
153151

152+
# Clean up, just in case some other test modified them
153+
_products = []
154+
_sum_operands = []
155+
_sums = []
156+
157+
config.set('execution', 'remove_unnecessary_outputs', ['false', 'true'][needed_outputs])
154158
# Make the workflow.
155159
wf = pe.Workflow(name='test')
156160
# the iterated input node
157161
inputspec = pe.Node(IdentityInterface(fields=['n']), name='inputspec')
158162
inputspec.iterables = [('n', [1, 2])]
159163
# a pre-join node in the iterated path
160164
pre_join1 = pe.Node(IncrementInterface(), name='pre_join1')
161-
wf.connect(inputspec, 'n', pre_join1, 'input1')
162165
# another pre-join node in the iterated path
163166
pre_join2 = pe.Node(IncrementInterface(), name='pre_join2')
164-
wf.connect(pre_join1, 'output1', pre_join2, 'input1')
165167
# the join node
166168
join = pe.JoinNode(
167169
SumInterface(),
168170
joinsource='inputspec',
169171
joinfield='input1',
170172
name='join')
171-
wf.connect(pre_join2, 'output1', join, 'input1')
172173
# an uniterated post-join node
173174
post_join1 = pe.Node(IncrementInterface(), name='post_join1')
174-
wf.connect(join, 'output1', post_join1, 'input1')
175175
# a post-join node in the iterated path
176176
post_join2 = pe.Node(ProductInterface(), name='post_join2')
177-
wf.connect(join, 'output1', post_join2, 'input1')
178-
wf.connect(pre_join1, 'output1', post_join2, 'input2')
179177

178+
wf.connect([
179+
(inputspec, pre_join1, [('n', 'input1')]),
180+
(pre_join1, pre_join2, [('output1', 'input1')]),
181+
(pre_join1, post_join2, [('output1', 'input2')]),
182+
(pre_join2, join, [('output1', 'input1')]),
183+
(join, post_join1, [('output1', 'input1')]),
184+
(join, post_join2, [('output1', 'input1')]),
185+
])
180186
result = wf.run()
181187

182188
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +191,8 @@ def test_join_expansion(tmpdir):
185191
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186192
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187193
# Nipype factors away the IdentityInterface.
188-
assert len(
189-
result.nodes()) == 8, "The number of expanded nodes is incorrect."
194+
assert len(result.nodes()) == 8, "The number of expanded nodes is incorrect."
195+
190196
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191197
assert len(_sums) == 1, "The number of join outputs is incorrect"
192198
assert _sums[
@@ -199,6 +205,7 @@ def test_join_expansion(tmpdir):
199205
"The number of iterated post-join outputs is incorrect"
200206

201207

208+
202209
def test_node_joinsource(tmpdir):
203210
"""Test setting the joinsource to a Node."""
204211
tmpdir.chdir()

0 commit comments

Comments
 (0)