Skip to content

Commit 37e48ec

Browse files
authored
Merge pull request #2830 from skoudoro/dipy-worflows-integration
[ENH] Add interfaces wrapping DIPY worflows
2 parents b0ce2e1 + ebe49d0 commit 37e48ec

File tree

7 files changed

+337
-4
lines changed

7 files changed

+337
-4
lines changed

doc/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
these packages within a single workflow. Nipype provides an environment
1717
that encourages interactive exploration of algorithms from different
1818
packages (e.g., ANTS_, SPM_, FSL_, FreeSurfer_, Camino_, MRtrix_, MNE_, AFNI_,
19-
Slicer_), eases the design of workflows within and between packages, and
19+
Slicer_, DIPY_), eases the design of workflows within and between packages, and
2020
reduces the learning curve necessary to use different packages. Nipype is
2121
creating a collaborative platform for neuroimaging software development
2222
in a high-level language and addressing limitations of existing pipeline

doc/links_names.txt

+1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
.. _MRtrix3: http://www.mrtrix.org/
9898
.. _MNE: https://martinos.org/mne/index.html
9999
.. _ANTS: http://stnava.github.io/ANTs/
100+
.. _DIPY: http://dipy.org
100101

101102
.. General software
102103
.. _gcc: http://gcc.gnu.org

nipype/interfaces/dipy/base.py

+137-1
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
absolute_import)
55

66
import os.path as op
7+
import inspect
78
import numpy as np
89
from ... import logging
910
from ..base import (traits, File, isdefined, LibraryBaseInterface,
10-
BaseInterfaceInputSpec)
11+
BaseInterfaceInputSpec, TraitedSpec)
1112

1213
HAVE_DIPY = True
1314
try:
1415
import dipy
16+
from dipy.workflows.base import IntrospectiveArgumentParser
1517
except ImportError:
1618
HAVE_DIPY = False
1719

@@ -75,3 +77,137 @@ def _gen_filename(self, name, ext=None):
7577
ext = fext
7678

7779
return out_prefix + '_' + name + ext
80+
81+
82+
def convert_to_traits_type(dipy_type, is_file=False):
83+
"""Convert DIPY type to Traits type."""
84+
dipy_type = dipy_type.lower()
85+
is_mandatory = bool("optional" not in dipy_type)
86+
if "variable" in dipy_type and "string" in dipy_type:
87+
return traits.ListStr, is_mandatory
88+
elif "variable" in dipy_type and "int" in dipy_type:
89+
return traits.ListInt, is_mandatory
90+
elif "variable" in dipy_type and "float" in dipy_type:
91+
return traits.ListFloat, is_mandatory
92+
elif "variable" in dipy_type and "bool" in dipy_type:
93+
return traits.ListBool, is_mandatory
94+
elif "variable" in dipy_type and "complex" in dipy_type:
95+
return traits.ListComplex, is_mandatory
96+
elif "string" in dipy_type and not is_file:
97+
return traits.Str, is_mandatory
98+
elif "string" in dipy_type and is_file:
99+
return traits.File, is_mandatory
100+
elif "int" in dipy_type:
101+
return traits.Int, is_mandatory
102+
elif "float" in dipy_type:
103+
return traits.Float, is_mandatory
104+
elif "bool" in dipy_type:
105+
return traits.Bool, is_mandatory
106+
elif "complex" in dipy_type:
107+
return traits.Complex, is_mandatory
108+
else:
109+
msg = "Error during convert_to_traits_type({0}).".format(dipy_type) + \
110+
"Unknown DIPY type."
111+
raise IOError(msg)
112+
113+
114+
def create_interface_specs(class_name, params=None, BaseClass=TraitedSpec):
115+
"""Create IN/Out interface specifications dynamically.
116+
117+
Parameters
118+
----------
119+
class_name: str
120+
The future class name(e.g, (MyClassInSpec))
121+
params: list of tuple
122+
dipy argument list
123+
BaseClass: TraitedSpec object
124+
parent class
125+
126+
Returns
127+
-------
128+
newclass: object
129+
new nipype interface specification class
130+
131+
"""
132+
attr = {}
133+
if params is not None:
134+
for p in params:
135+
name, dipy_type, desc = p[0], p[1], p[2]
136+
is_file = bool("files" in name or "out_" in name)
137+
traits_type, is_mandatory = convert_to_traits_type(dipy_type,
138+
is_file)
139+
# print(name, dipy_type, desc, is_file, traits_type, is_mandatory)
140+
if BaseClass.__name__ == BaseInterfaceInputSpec.__name__:
141+
if len(p) > 3:
142+
attr[name] = traits_type(p[3], desc=desc[-1],
143+
usedefault=True,
144+
mandatory=is_mandatory)
145+
else:
146+
attr[name] = traits_type(desc=desc[-1],
147+
mandatory=is_mandatory)
148+
else:
149+
attr[name] = traits_type(p[3], desc=desc[-1], exists=True,
150+
usedefault=True,)
151+
152+
newclass = type(str(class_name), (BaseClass, ), attr)
153+
return newclass
154+
155+
156+
def dipy_to_nipype_interface(cls_name, dipy_flow, BaseClass=DipyBaseInterface):
157+
"""Construct a class in order to respect nipype interface specifications.
158+
159+
This convenient class factory convert a DIPY Workflow to a nipype
160+
interface.
161+
162+
Parameters
163+
----------
164+
cls_name: string
165+
new class name
166+
dipy_flow: Workflow class type.
167+
It should be any children class of `dipy.workflows.workflow.Worflow`
168+
BaseClass: object
169+
nipype instance object
170+
171+
Returns
172+
-------
173+
newclass: object
174+
new nipype interface specification class
175+
176+
"""
177+
parser = IntrospectiveArgumentParser()
178+
flow = dipy_flow()
179+
parser.add_workflow(flow)
180+
default_values = inspect.getargspec(flow.run).defaults
181+
optional_params = [args + (val,) for args, val in zip(parser.optional_parameters, default_values)]
182+
start = len(parser.optional_parameters) - len(parser.output_parameters)
183+
184+
output_parameters = [args + (val,) for args, val in zip(parser.output_parameters, default_values[start:])]
185+
input_parameters = parser.positional_parameters + optional_params
186+
187+
input_spec = create_interface_specs("{}InputSpec".format(cls_name),
188+
input_parameters,
189+
BaseClass=BaseInterfaceInputSpec)
190+
191+
output_spec = create_interface_specs("{}OutputSpec".format(cls_name),
192+
output_parameters,
193+
BaseClass=TraitedSpec)
194+
195+
def _run_interface(self, runtime):
196+
flow = dipy_flow()
197+
args = self.inputs.get()
198+
flow.run(**args)
199+
200+
def _list_outputs(self):
201+
outputs = self._outputs().get()
202+
out_dir = outputs.get("out_dir", ".")
203+
for key, values in outputs.items():
204+
outputs[key] = op.join(out_dir, values)
205+
206+
return outputs
207+
208+
newclass = type(str(cls_name), (BaseClass, ),
209+
{"input_spec": input_spec,
210+
"output_spec": output_spec,
211+
"_run_interface": _run_interface,
212+
"_list_outputs:": _list_outputs})
213+
return newclass

nipype/interfaces/dipy/reconstruction.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,30 @@
1313

1414
import numpy as np
1515
import nibabel as nb
16+
from distutils.version import LooseVersion
1617

1718
from ... import logging
1819
from ..base import TraitedSpec, File, traits, isdefined
19-
from .base import DipyDiffusionInterface, DipyBaseInterfaceInputSpec
20+
from .base import (DipyDiffusionInterface, DipyBaseInterfaceInputSpec,
21+
HAVE_DIPY, dipy_version, dipy_to_nipype_interface)
22+
2023

2124
IFLOGGER = logging.getLogger('nipype.interface')
2225

26+
if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
27+
from dipy.workflows.reconst import (ReconstDkiFlow, ReconstCSAFlow,
28+
ReconstCSDFlow, ReconstMAPMRIFlow,
29+
ReconstDtiFlow)
30+
31+
DKIModel = dipy_to_nipype_interface("DKIModel", ReconstDkiFlow)
32+
MapmriModel = dipy_to_nipype_interface("MapmriModel", ReconstMAPMRIFlow)
33+
DTIModel = dipy_to_nipype_interface("DTIModel", ReconstDtiFlow)
34+
CSAModel = dipy_to_nipype_interface("CSAModel", ReconstCSAFlow)
35+
CSDModel = dipy_to_nipype_interface("CSDModel", ReconstCSDFlow)
36+
else:
37+
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
38+
" activate DKIModel, MapmriModel, DTIModel, CSAModel, CSDModel.")
39+
2340

2441
class RESTOREInputSpec(DipyBaseInterfaceInputSpec):
2542
in_mask = File(exists=True, desc=('input mask in which compute tensors'))
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
from distutils.version import LooseVersion
3+
from ... import logging
4+
from .base import HAVE_DIPY, dipy_version, dipy_to_nipype_interface
5+
6+
IFLOGGER = logging.getLogger('nipype.interface')
7+
8+
if HAVE_DIPY and LooseVersion(dipy_version()) >= LooseVersion('0.15'):
9+
10+
from dipy.workflows.align import ResliceFlow, SlrWithQbxFlow
11+
12+
Reslice = dipy_to_nipype_interface("Reslice", ResliceFlow)
13+
StreamlineRegistration = dipy_to_nipype_interface("StreamlineRegistration",
14+
SlrWithQbxFlow)
15+
16+
else:
17+
IFLOGGER.info("We advise you to upgrade DIPY version. This upgrade will"
18+
" activate Reslice, StreamlineRegistration.")
+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import pytest
2+
from collections import namedtuple
3+
from ...base import traits, TraitedSpec, BaseInterfaceInputSpec
4+
from ..base import (convert_to_traits_type, create_interface_specs,
5+
dipy_to_nipype_interface, DipyBaseInterface, no_dipy)
6+
7+
8+
def test_convert_to_traits_type():
9+
Params = namedtuple("Params", "traits_type is_file")
10+
Res = namedtuple("Res", "traits_type is_mandatory")
11+
l_entries = [Params('variable string', False),
12+
Params('variable int', False),
13+
Params('variable float', False),
14+
Params('variable bool', False),
15+
Params('variable complex', False),
16+
Params('variable int, optional', False),
17+
Params('variable string, optional', False),
18+
Params('variable float, optional', False),
19+
Params('variable bool, optional', False),
20+
Params('variable complex, optional', False),
21+
Params('string', False), Params('int', False),
22+
Params('string', True), Params('float', False),
23+
Params('bool', False), Params('complex', False),
24+
Params('string, optional', False),
25+
Params('int, optional', False),
26+
Params('string, optional', True),
27+
Params('float, optional', False),
28+
Params('bool, optional', False),
29+
Params('complex, optional', False),
30+
]
31+
l_expected = [Res(traits.ListStr, True), Res(traits.ListInt, True),
32+
Res(traits.ListFloat, True), Res(traits.ListBool, True),
33+
Res(traits.ListComplex, True), Res(traits.ListInt, False),
34+
Res(traits.ListStr, False), Res(traits.ListFloat, False),
35+
Res(traits.ListBool, False), Res(traits.ListComplex, False),
36+
Res(traits.Str, True), Res(traits.Int, True),
37+
Res(traits.File, True), Res(traits.Float, True),
38+
Res(traits.Bool, True), Res(traits.Complex, True),
39+
Res(traits.Str, False), Res(traits.Int, False),
40+
Res(traits.File, False), Res(traits.Float, False),
41+
Res(traits.Bool, False), Res(traits.Complex, False),
42+
]
43+
44+
for entry, res in zip(l_entries, l_expected):
45+
traits_type, is_mandatory = convert_to_traits_type(entry.traits_type,
46+
entry.is_file)
47+
assert traits_type == res.traits_type
48+
assert is_mandatory == res.is_mandatory
49+
50+
with pytest.raises(IOError):
51+
convert_to_traits_type("file, optional")
52+
53+
54+
def test_create_interface_specs():
55+
new_interface = create_interface_specs("MyInterface")
56+
57+
assert new_interface.__base__ == TraitedSpec
58+
assert isinstance(new_interface(), TraitedSpec)
59+
assert new_interface.__name__ == "MyInterface"
60+
assert not new_interface().get()
61+
62+
new_interface = create_interface_specs("MyInterface",
63+
BaseClass=BaseInterfaceInputSpec)
64+
assert new_interface.__base__ == BaseInterfaceInputSpec
65+
assert isinstance(new_interface(), BaseInterfaceInputSpec)
66+
assert new_interface.__name__ == "MyInterface"
67+
assert not new_interface().get()
68+
69+
params = [("params1", "string", ["my description"]), ("params2_files", "string", ["my description @"]),
70+
("params3", "int, optional", ["useful option"]), ("out_params", "string", ["my out description"])]
71+
72+
new_interface = create_interface_specs("MyInterface", params=params,
73+
BaseClass=BaseInterfaceInputSpec)
74+
75+
assert new_interface.__base__ == BaseInterfaceInputSpec
76+
assert isinstance(new_interface(), BaseInterfaceInputSpec)
77+
assert new_interface.__name__ == "MyInterface"
78+
current_params = new_interface().get()
79+
assert len(current_params) == 4
80+
assert 'params1' in current_params.keys()
81+
assert 'params2_files' in current_params.keys()
82+
assert 'params3' in current_params.keys()
83+
assert 'out_params' in current_params.keys()
84+
85+
86+
@pytest.mark.skipif(no_dipy(), reason="DIPY is not installed")
87+
def test_dipy_to_nipype_interface():
88+
from dipy.workflows.workflow import Workflow
89+
90+
class DummyWorkflow(Workflow):
91+
92+
@classmethod
93+
def get_short_name(cls):
94+
return 'dwf1'
95+
96+
def run(self, in_files, param1=1, out_dir='', out_ref='out1.txt'):
97+
"""Workflow used to test basic workflows.
98+
99+
Parameters
100+
----------
101+
in_files : string
102+
fake input string param
103+
param1 : int, optional
104+
fake positional param (default 1)
105+
out_dir : string, optional
106+
fake output directory (default '')
107+
out_ref : string, optional
108+
fake out file (default out1.txt)
109+
110+
References
111+
-----------
112+
dummy references
113+
114+
"""
115+
return param1
116+
117+
new_specs = dipy_to_nipype_interface("MyModelSpec", DummyWorkflow)
118+
assert new_specs.__base__ == DipyBaseInterface
119+
assert isinstance(new_specs(), DipyBaseInterface)
120+
assert new_specs.__name__ == "MyModelSpec"
121+
assert hasattr(new_specs, 'input_spec')
122+
assert new_specs().input_spec.__base__ == BaseInterfaceInputSpec
123+
assert hasattr(new_specs, 'output_spec')
124+
assert new_specs().output_spec.__base__ == TraitedSpec
125+
assert hasattr(new_specs, '_run_interface')
126+
assert hasattr(new_specs, '_list_outputs')
127+
params_in = new_specs().inputs.get()
128+
params_out = new_specs()._outputs().get()
129+
assert len(params_in) == 4
130+
assert 'in_files' in params_in.keys()
131+
assert 'param1' in params_in.keys()
132+
assert 'out_dir' in params_out.keys()
133+
assert 'out_ref' in params_out.keys()
134+
135+
with pytest.raises(ValueError):
136+
new_specs().run()
137+
138+
139+
if __name__ == "__main__":
140+
test_convert_to_traits_type()
141+
test_create_interface_specs()
142+
test_dipy_to_nipype_interface()

0 commit comments

Comments
 (0)