6
6
from __future__ import (print_function , division , unicode_literals ,
7
7
absolute_import )
8
8
from builtins import open
9
+ import pytest
9
10
11
+ from .... import config
10
12
from ... import engine as pe
11
13
from ....interfaces import base as nib
12
14
from ....interfaces .utility import IdentityInterface , Function , Merge
@@ -45,19 +47,15 @@ class IncrementOutputSpec(nib.TraitedSpec):
45
47
output1 = nib .traits .Int (desc = 'ouput' )
46
48
47
49
48
- class IncrementInterface (nib .BaseInterface ):
50
+ class IncrementInterface (nib .SimpleInterface ):
49
51
input_spec = IncrementInputSpec
50
52
output_spec = IncrementOutputSpec
51
53
52
54
def _run_interface (self , runtime ):
53
55
runtime .returncode = 0
56
+ self ._results ['output1' ] = self .inputs .input1 + self .inputs .inc
54
57
return runtime
55
58
56
- def _list_outputs (self ):
57
- outputs = self ._outputs ().get ()
58
- outputs ['output1' ] = self .inputs .input1 + self .inputs .inc
59
- return outputs
60
-
61
59
62
60
_sums = []
63
61
@@ -73,23 +71,19 @@ class SumOutputSpec(nib.TraitedSpec):
73
71
operands = nib .traits .List (nib .traits .Int , desc = 'operands' )
74
72
75
73
76
- class SumInterface (nib .BaseInterface ):
74
+ class SumInterface (nib .SimpleInterface ):
77
75
input_spec = SumInputSpec
78
76
output_spec = SumOutputSpec
79
77
80
78
def _run_interface (self , runtime ):
81
- runtime .returncode = 0
82
- return runtime
83
-
84
- def _list_outputs (self ):
85
79
global _sum
86
80
global _sum_operands
87
- outputs = self . _outputs (). get ()
88
- outputs ['operands' ] = self .inputs .input1
89
- _sum_operands . append ( outputs [ 'operands' ] )
90
- outputs [ 'output1' ] = sum (self .inputs .input1 )
91
- _sums .append (outputs [ 'output1' ] )
92
- return outputs
81
+ runtime . returncode = 0
82
+ self . _results ['operands' ] = self .inputs .input1
83
+ self . _results [ 'output1' ] = sum ( self . inputs . input1 )
84
+ _sum_operands . append (self .inputs .input1 )
85
+ _sums .append (sum ( self . inputs . input1 ) )
86
+ return runtime
93
87
94
88
95
89
_set_len = None
@@ -148,35 +142,47 @@ def _list_outputs(self):
148
142
return outputs
149
143
150
144
151
- def test_join_expansion (tmpdir ):
145
+ @pytest .mark .parametrize ('needed_outputs' , [True , False ])
146
+ def test_join_expansion (tmpdir , needed_outputs ):
147
+ global _sums
148
+ global _sum_operands
149
+ global _products
152
150
tmpdir .chdir ()
153
151
152
+ # Clean up, just in case some other test modified them
153
+ _products = []
154
+ _sum_operands = []
155
+ _sums = []
156
+
157
+ config .set ('execution' , 'remove_unnecessary_outputs' , ['false' , 'true' ][needed_outputs ])
154
158
# Make the workflow.
155
159
wf = pe .Workflow (name = 'test' )
156
160
# the iterated input node
157
161
inputspec = pe .Node (IdentityInterface (fields = ['n' ]), name = 'inputspec' )
158
162
inputspec .iterables = [('n' , [1 , 2 ])]
159
163
# a pre-join node in the iterated path
160
164
pre_join1 = pe .Node (IncrementInterface (), name = 'pre_join1' )
161
- wf .connect (inputspec , 'n' , pre_join1 , 'input1' )
162
165
# another pre-join node in the iterated path
163
166
pre_join2 = pe .Node (IncrementInterface (), name = 'pre_join2' )
164
- wf .connect (pre_join1 , 'output1' , pre_join2 , 'input1' )
165
167
# the join node
166
168
join = pe .JoinNode (
167
169
SumInterface (),
168
170
joinsource = 'inputspec' ,
169
171
joinfield = 'input1' ,
170
172
name = 'join' )
171
- wf .connect (pre_join2 , 'output1' , join , 'input1' )
172
173
# an uniterated post-join node
173
174
post_join1 = pe .Node (IncrementInterface (), name = 'post_join1' )
174
- wf .connect (join , 'output1' , post_join1 , 'input1' )
175
175
# a post-join node in the iterated path
176
176
post_join2 = pe .Node (ProductInterface (), name = 'post_join2' )
177
- wf .connect (join , 'output1' , post_join2 , 'input1' )
178
- wf .connect (pre_join1 , 'output1' , post_join2 , 'input2' )
179
177
178
+ wf .connect ([
179
+ (inputspec , pre_join1 , [('n' , 'input1' )]),
180
+ (pre_join1 , pre_join2 , [('output1' , 'input1' )]),
181
+ (pre_join1 , post_join2 , [('output1' , 'input2' )]),
182
+ (pre_join2 , join , [('output1' , 'input1' )]),
183
+ (join , post_join1 , [('output1' , 'input1' )]),
184
+ (join , post_join2 , [('output1' , 'input1' )]),
185
+ ])
180
186
result = wf .run ()
181
187
182
188
# the two expanded pre-join predecessor nodes feed into one join node
@@ -185,8 +191,8 @@ def test_join_expansion(tmpdir):
185
191
# the expanded graph contains 2 * 2 = 4 iteration pre-join nodes, 1 join
186
192
# node, 1 non-iterated post-join node and 2 * 1 iteration post-join nodes.
187
193
# Nipype factors away the IdentityInterface.
188
- assert len (
189
- result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
194
+ assert len (result . nodes ()) == 8 , "The number of expanded nodes is incorrect."
195
+
190
196
# the join Sum result is (1 + 1 + 1) + (2 + 1 + 1)
191
197
assert len (_sums ) == 1 , "The number of join outputs is incorrect"
192
198
assert _sums [
@@ -199,6 +205,7 @@ def test_join_expansion(tmpdir):
199
205
"The number of iterated post-join outputs is incorrect"
200
206
201
207
208
+
202
209
def test_node_joinsource (tmpdir ):
203
210
"""Test setting the joinsource to a Node."""
204
211
tmpdir .chdir ()
0 commit comments