Skip to content

ENH: Add ConstrainedSphericalDeconvolution interface to replace EstimateFOD for MRtrix3's dwi2fod #3176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 24, 2020
2 changes: 1 addition & 1 deletion nipype/interfaces/mrtrix3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@
DWIBiasCorrect,
)
from .tracking import Tractography
from .reconst import FitTensor, EstimateFOD
from .reconst import FitTensor, EstimateFOD, ConstrainedSphericalDeconvolution
from .connectivity import LabelConfig, LabelConvert, BuildConnectome
57 changes: 53 additions & 4 deletions nipype/interfaces/mrtrix3/reconst.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,24 @@ class EstimateFOD(MRTrix3Base):
"""
Estimate fibre orientation distributions from diffusion data using spherical deconvolution

.. warning::

The CSD algorithm does not work as intended, but fixing it in this interface could break
existing workflows. This interface has been superseded by
:py:class:`.ConstrainedSphericalDecomposition`.

Example
-------

>>> import nipype.interfaces.mrtrix3 as mrt
>>> fod = mrt.EstimateFOD()
>>> fod.inputs.algorithm = 'csd'
>>> fod.inputs.algorithm = 'msmt_csd'
>>> fod.inputs.in_file = 'dwi.mif'
>>> fod.inputs.wm_txt = 'wm.txt'
>>> fod.inputs.grad_fsl = ('bvecs', 'bvals')
>>> fod.cmdline # doctest: +ELLIPSIS
'dwi2fod -fslgrad bvecs bvals -lmax 8 csd dwi.mif wm.txt wm.mif gm.mif csf.mif'
>>> fod.run() # doctest: +SKIP
>>> fod.cmdline
'dwi2fod -fslgrad bvecs bvals -lmax 8 msmt_csd dwi.mif wm.txt wm.mif gm.mif csf.mif'
>>> fod.run() # doctest: +SKIP
"""

_cmd = "dwi2fod"
Expand All @@ -182,3 +188,46 @@ def _list_outputs(self):
if self.inputs.csf_odf != Undefined:
outputs["csf_odf"] = op.abspath(self.inputs.csf_odf)
return outputs


class ConstrainedSphericalDeconvolutionInputSpec(EstimateFODInputSpec):
gm_odf = File(argstr="%s", position=-3, desc="output GM ODF")
csf_odf = File(argstr="%s", position=-1, desc="output CSF ODF")
max_sh = InputMultiObject(
traits.Int,
argstr="-lmax %s",
sep=",",
desc=(
"maximum harmonic degree of response function - single value for single-shell response, list for multi-shell response"
),
)


class ConstrainedSphericalDeconvolution(EstimateFOD):
"""
Estimate fibre orientation distributions from diffusion data using spherical deconvolution

This interface supersedes :py:class:`.EstimateFOD`.
The old interface has contained a bug when using the CSD algorithm as opposed to the MSMT CSD
algorithm, but fixing it could potentially break existing workflows. The new interface works
the same, but does not populate the following inputs by default:

* ``gm_odf``
* ``csf_odf``
* ``max_sh``

Example
-------

>>> import nipype.interfaces.mrtrix3 as mrt
>>> fod = mrt.ConstrainedSphericalDeconvolution()
>>> fod.inputs.algorithm = 'csd'
>>> fod.inputs.in_file = 'dwi.mif'
>>> fod.inputs.wm_txt = 'wm.txt'
>>> fod.inputs.grad_fsl = ('bvecs', 'bvals')
>>> fod.cmdline
'dwi2fod -fslgrad bvecs bvals csd dwi.mif wm.txt wm.mif'
>>> fod.run() # doctest: +SKIP
"""

input_spec = ConstrainedSphericalDeconvolutionInputSpec
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
from ..reconst import ConstrainedSphericalDeconvolution


def test_ConstrainedSphericalDeconvolution_inputs():
input_map = dict(
algorithm=dict(argstr="%s", mandatory=True, position=-8,),
args=dict(argstr="%s",),
bval_scale=dict(argstr="-bvalue_scaling %s",),
csf_odf=dict(argstr="%s", extensions=None, position=-1,),
csf_txt=dict(argstr="%s", extensions=None, position=-2,),
environ=dict(nohash=True, usedefault=True,),
gm_odf=dict(argstr="%s", extensions=None, position=-3,),
gm_txt=dict(argstr="%s", extensions=None, position=-4,),
grad_file=dict(argstr="-grad %s", extensions=None, xor=["grad_fsl"],),
grad_fsl=dict(argstr="-fslgrad %s %s", xor=["grad_file"],),
in_bval=dict(extensions=None,),
in_bvec=dict(argstr="-fslgrad %s %s", extensions=None,),
in_dirs=dict(argstr="-directions %s", extensions=None,),
in_file=dict(argstr="%s", extensions=None, mandatory=True, position=-7,),
mask_file=dict(argstr="-mask %s", extensions=None,),
max_sh=dict(argstr="-lmax %s", sep=",",),
nthreads=dict(argstr="-nthreads %d", nohash=True,),
shell=dict(argstr="-shell %s", sep=",",),
wm_odf=dict(
argstr="%s", extensions=None, mandatory=True, position=-5, usedefault=True,
),
wm_txt=dict(argstr="%s", extensions=None, mandatory=True, position=-6,),
)
inputs = ConstrainedSphericalDeconvolution.input_spec()

for key, metadata in list(input_map.items()):
for metakey, value in list(metadata.items()):
assert getattr(inputs.traits()[key], metakey) == value


def test_ConstrainedSphericalDeconvolution_outputs():
output_map = dict(
csf_odf=dict(argstr="%s", extensions=None,),
gm_odf=dict(argstr="%s", extensions=None,),
wm_odf=dict(argstr="%s", extensions=None,),
)
outputs = ConstrainedSphericalDeconvolution.output_spec()

for key, metadata in list(output_map.items()):
for metakey, value in list(metadata.items()):
assert getattr(outputs.traits()[key], metakey) == value
2 changes: 1 addition & 1 deletion tools/checkspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_specs(self, uri):
and "xor" not in trait.__dict__
):
if (
trait.trait_type.__class__.__name__ is "Range"
trait.trait_type.__class__.__name__ == "Range"
and trait.default == trait.trait_type._low
):
continue
Expand Down