6
6
from __future__ import (print_function , division , unicode_literals ,
7
7
absolute_import )
8
8
from builtins import open
9
+ import pytest
9
10
11
+ from .... import config
10
12
from ... import engine as pe
11
13
from ....interfaces import base as nib
12
14
from ....interfaces .utility import IdentityInterface , Function , Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
45
47
output1 = nib .traits .Int (desc = 'ouput' )
46
48
47
49
48
- class IncrementInterface (nib .BaseInterface ):
50
+ class IncrementInterface (nib .SimpleInterface ):
49
51
input_spec = IncrementInputSpec
50
52
output_spec = IncrementOutputSpec
51
53
52
54
def _run_interface (self , runtime ):
53
55
runtime .returncode = 0
56
+ self ._results ['output1' ] = self .inputs .input1 + self .inputs .inc
54
57
return runtime
55
58
56
- def _list_outputs (self ):
57
- outputs = self ._outputs ().get ()
58
- outputs ['output1' ] = self .inputs .input1 + self .inputs .inc
59
- return outputs
60
-
61
59
62
60
_sums = []
63
61
@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
73
71
operands = nib .traits .List (nib .traits .Int , desc = 'operands' )
74
72
75
73
76
- class SumInterface (nib .BaseInterface ):
74
+ class SumInterface (nib .SimpleInterface ):
77
75
input_spec = SumInputSpec
78
76
output_spec = SumOutputSpec
79
77
80
78
def _run_interface (self , runtime ):
81
- runtime .returncode = 0
82
- return runtime
83
-
84
- def _list_outputs (self ):
85
79
global _sum
86
80
global _sum_operands
87
- outputs = self . _outputs (). get ()
88
- outputs ['operands' ] = self .inputs .input1
89
- _sum_operands . append ( outputs [ 'operands' ] )
90
- outputs [ 'output1' ] = sum (self .inputs .input1 )
91
- _sums .append (outputs [ 'output1' ] )
92
- return outputs
81
+ runtime . returncode = 0
82
+ self . _results ['operands' ] = self .inputs .input1
83
+ self . _results [ 'output1' ] = sum ( self . inputs . input1 )
84
+ _sum_operands . append (self .inputs .input1 )
85
+ _sums .append (sum ( self . inputs . input1 ) )
86
+ return runtime
93
87
94
88
95
89
_set_len = None
@@ -148,35 +142,48 @@ def _list_outputs(self):
148
142
return outputs
149
143
150
144
151
- def test_join_expansion (tmpdir ):
145
+ @pytest .mark .parametrize ('needed_outputs' , ['true' , 'false' ])
146
+ def test_join_expansion (tmpdir , needed_outputs ):
147
+ global _sums
148
+ global _sum_operands
149
+ global _products
152
150
tmpdir .chdir ()
153
151
152
+ # Clean up, just in case some other test modified them
153
+ _products = []
154
+ _sum_operands = []
155
+ _sums = []
156
+
157
+ prev_state = config .get ('execution' , 'remove_unnecessary_outputs' )
158
+ config .set ('execution' , 'remove_unnecessary_outputs' , needed_outputs )
154
159
# Make the workflow.
155
160
wf = pe .Workflow (name = 'test' )
156
161
# the iterated input node
157
162
inputspec = pe .Node (IdentityInterface (fields = ['n' ]), name = 'inputspec' )
158
163
inputspec .iterables = [('n' , [1 , 2 ])]
159
164
# a pre-join node in the iterated path
160
165
pre_join1 = pe .Node (IncrementInterface (), name = 'pre_join1' )
161
- wf .connect (inputspec , 'n' , pre_join1 , 'input1' )
162
166
# another pre-join node in the iterated path
163
167
pre_join2 = pe .Node (IncrementInterface (), name = 'pre_join2' )
164
- wf .connect (pre_join1 , 'output1' , pre_join2 , 'input1' )
165
168
# the join node
166
169
join = pe .JoinNode (
167
170
SumInterface (),
168
171
joinsource = 'inputspec' ,
169
172
joinfield = 'input1' ,
170
173
name = 'join' )
171
- wf .connect (pre_join2 , 'output1' , join , 'input1' )
172
174
# an uniterated post-join node
173
175
post_join1 = pe .Node (IncrementInterface (), name = 'post_join1' )
174
- wf .connect (join , 'output1' , post_join1 , 'input1' )
175
176
# a post-join node in the iterated path
176
177
post_join2 = pe .Node (ProductInterface (), name = 'post_join2' )
177
- wf .connect (join , 'output1' , post_join2 , 'input1' )
178
- wf .connect (pre_join1 , 'output1' , post_join2 , 'input2' )
179
178
179
+ wf .connect ([
180
+ (inputspec , pre_join1 , [('n' , 'input1' )]),
181
+ (pre_join1 , pre_join2 , [('output1' , 'input1' )]),
182
+ (pre_join1 , post_join2 , [('output1' , 'input2' )]),
183
+ (pre_join2 , join , [('output1' , 'input1' )]),
184
+ (join , post_join1 , [('output1' , 'input1' )]),
185
+ (join , post_join2 , [('output1' , 'input1' )]),
186
+ ])
180
187
result = wf .run ()
181
188
182
189
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +192,8 @@ def test_join_expansion(tmpdir):
185
192
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186
193
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187
194
# Nipype factors away the IdentityInterface.
188
- assert len (
189
- result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
195
+ assert len (result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
196
+
190
197
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191
198
assert len (_sums ) == 1 , "The number of join outputs is incorrect"
192
199
assert _sums [
@@ -197,6 +204,7 @@ def test_join_expansion(tmpdir):
197
204
# there are two iterations of the post-join node in the iterable path
198
205
assert len (_products ) == 2 ,\
199
206
"The number of iterated post-join outputs is incorrect"
207
+ config .set ('execution' , 'remove_unnecessary_outputs' , prev_state )
200
208
201
209
202
210
def test_node_joinsource (tmpdir ):
0 commit comments