diff --git a/nipype/info.py b/nipype/info.py index 1cf361c40c..da3bc8311f 100644 --- a/nipype/info.py +++ b/nipype/info.py @@ -108,7 +108,7 @@ def get_nipype_gitversion(): SCIPY_MIN_VERSION = '0.14' TRAITS_MIN_VERSION = '4.6' DATEUTIL_MIN_VERSION = '2.2' -PYTEST_MIN_VERSION = '3.0' +PYTEST_MIN_VERSION = '3.4' FUTURE_MIN_VERSION = '0.16.0' SIMPLEJSON_MIN_VERSION = '3.8.0' PROV_VERSION = '1.5.2' @@ -159,7 +159,12 @@ def get_nipype_gitversion(): if sys.version_info <= (3, 4): REQUIRES.append('configparser') -TESTS_REQUIRES = ['pytest-cov', 'codecov', 'pytest-env', 'coverage<5'] +TESTS_REQUIRES = [ + 'pytest-cov', + 'pytest-env', + 'codecov', + 'coverage<5', +] EXTRA_REQUIRES = { 'doc': ['Sphinx>=1.4', 'numpydoc', 'matplotlib', 'pydotplus', 'pydot>=1.2.3'], diff --git a/nipype/interfaces/ants/tests/test_extra_Registration.py b/nipype/interfaces/ants/tests/test_extra_Registration.py index 745b825c65..c65bb445be 100644 --- a/nipype/interfaces/ants/tests/test_extra_Registration.py +++ b/nipype/interfaces/ants/tests/test_extra_Registration.py @@ -1,9 +1,10 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: from __future__ import unicode_literals -from nipype.interfaces.ants import registration import os import pytest +from nipype.interfaces.ants import registration +from nipype.utils.errors import MandatoryInputError def test_ants_mand(tmpdir): @@ -17,6 +18,6 @@ def test_ants_mand(tmpdir): ants.inputs.fixed_image = [os.path.join(datadir, 'T1.nii')] ants.inputs.metric = ['MI'] - with pytest.raises(ValueError) as er: + with pytest.raises(MandatoryInputError) as er: ants.run() - assert "ANTS requires a value for input 'radius'" in str(er.value) + assert 'Interface "ANTS" requires a value for input radius.' in str(er.value) diff --git a/nipype/interfaces/base/core.py b/nipype/interfaces/base/core.py index 6177b449f9..6f36dcf4a5 100644 --- a/nipype/interfaces/base/core.py +++ b/nipype/interfaces/base/core.py @@ -29,7 +29,7 @@ import simplejson as json from dateutil.parser import parse as parseutc -from ... import config, logging, LooseVersion +from ... import config, logging from ...utils.provenance import write_provenance from ...utils.misc import trim, str2bool, rgetcwd from ...utils.filemanip import (FileNotFoundError, split_filename, @@ -41,7 +41,7 @@ from .traits_extension import traits, isdefined, TraitError from .specs import (BaseInterfaceInputSpec, CommandLineInputSpec, StdOutCommandLineInputSpec, MpiCommandLineInputSpec, - get_filecopy_info) + get_filecopy_info, check_mandatory_inputs, check_version) from .support import (Bunch, InterfaceResult, NipypeInterfaceError) from future import standard_library @@ -335,93 +335,6 @@ def _outputs(self): return outputs - def _check_requires(self, spec, name, value): - """ check if required inputs are satisfied - """ - if spec.requires: - values = [ - not isdefined(getattr(self.inputs, field)) - for field in spec.requires - ] - if any(values) and isdefined(value): - msg = ("%s requires a value for input '%s' because one of %s " - "is set. For a list of required inputs, see %s.help()" % - (self.__class__.__name__, name, - ', '.join(spec.requires), self.__class__.__name__)) - raise ValueError(msg) - - def _check_xor(self, spec, name, value): - """ check if mutually exclusive inputs are satisfied - """ - if spec.xor: - values = [ - isdefined(getattr(self.inputs, field)) for field in spec.xor - ] - if not any(values) and not isdefined(value): - msg = ("%s requires a value for one of the inputs '%s'. " - "For a list of required inputs, see %s.help()" % - (self.__class__.__name__, ', '.join(spec.xor), - self.__class__.__name__)) - raise ValueError(msg) - - def _check_mandatory_inputs(self): - """ Raises an exception if a mandatory input is Undefined - """ - for name, spec in list(self.inputs.traits(mandatory=True).items()): - value = getattr(self.inputs, name) - self._check_xor(spec, name, value) - if not isdefined(value) and spec.xor is None: - msg = ("%s requires a value for input '%s'. " - "For a list of required inputs, see %s.help()" % - (self.__class__.__name__, name, - self.__class__.__name__)) - raise ValueError(msg) - if isdefined(value): - self._check_requires(spec, name, value) - for name, spec in list( - self.inputs.traits(mandatory=None, transient=None).items()): - self._check_requires(spec, name, getattr(self.inputs, name)) - - def _check_version_requirements(self, trait_object, raise_exception=True): - """ Raises an exception on version mismatch - """ - unavailable_traits = [] - # check minimum version - check = dict(min_ver=lambda t: t is not None) - names = trait_object.trait_names(**check) - - if names and self.version: - version = LooseVersion(str(self.version)) - for name in names: - min_ver = LooseVersion( - str(trait_object.traits()[name].min_ver)) - if min_ver > version: - unavailable_traits.append(name) - if not isdefined(getattr(trait_object, name)): - continue - if raise_exception: - raise Exception( - 'Trait %s (%s) (version %s < required %s)' % - (name, self.__class__.__name__, version, min_ver)) - - # check maximum version - check = dict(max_ver=lambda t: t is not None) - names = trait_object.trait_names(**check) - if names and self.version: - version = LooseVersion(str(self.version)) - for name in names: - max_ver = LooseVersion( - str(trait_object.traits()[name].max_ver)) - if max_ver < version: - unavailable_traits.append(name) - if not isdefined(getattr(trait_object, name)): - continue - if raise_exception: - raise Exception( - 'Trait %s (%s) (version %s > required %s)' % - (name, self.__class__.__name__, version, max_ver)) - return unavailable_traits - def _run_interface(self, runtime): """ Core function that executes interface """ @@ -466,8 +379,8 @@ def run(self, cwd=None, ignore_exception=None, **inputs): enable_rm = config.resource_monitor and self.resource_monitor self.inputs.trait_set(**inputs) - self._check_mandatory_inputs() - self._check_version_requirements(self.inputs) + check_mandatory_inputs(self.inputs) + check_version(self.inputs, version=self.version) interface = self.__class__ self._duecredit_cite() @@ -593,8 +506,8 @@ def aggregate_outputs(self, runtime=None, needed_outputs=None): if predicted_outputs: _unavailable_outputs = [] if outputs: - _unavailable_outputs = \ - self._check_version_requirements(self._outputs()) + _unavailable_outputs = check_version( + self._outputs(), self.version) for key, val in list(predicted_outputs.items()): if needed_outputs and key not in needed_outputs: continue @@ -810,7 +723,14 @@ def cmd(self): def cmdline(self): """ `command` plus any arguments (args) validates arguments and generates command line""" - self._check_mandatory_inputs() + if not check_mandatory_inputs(self.inputs, raise_exc=False): + iflogger.warning( + 'Command line could not be generated because some inputs ' + 'are not valid. Please make sure all mandatory inputs, ' + 'required inputs and mutually-exclusive inputs are set ' + 'or in a sane state.') + return None + allargs = [self._cmd_prefix + self.cmd] + self._parse_inputs() return ' '.join(allargs) diff --git a/nipype/interfaces/base/specs.py b/nipype/interfaces/base/specs.py index dbbc816dc9..d4c828699d 100644 --- a/nipype/interfaces/base/specs.py +++ b/nipype/interfaces/base/specs.py @@ -20,6 +20,9 @@ from packaging.version import Version from ...utils.filemanip import md5, hash_infile, hash_timestamp, to_str +from ...utils.errors import ( + MandatoryInputError, MutuallyExclusiveInputError, RequiredInputError, + VersionIOError) from .traits_extension import ( traits, Undefined, @@ -115,9 +118,8 @@ def _xor_warn(self, obj, name, old, new): trait_change_notify=False, **{ '%s' % name: Undefined }) - msg = ('Input "%s" is mutually exclusive with input "%s", ' - 'which is already set') % (name, trait_name) - raise IOError(msg) + raise MutuallyExclusiveInputError( + self, name, name_other=trait_name) def _deprecated_warn(self, obj, name, old, new): """Checks if a user assigns a value to a deprecated trait @@ -394,3 +396,117 @@ def get_filecopy_info(cls): for name, spec in sorted(inputs.traits(**metadata).items()): info.append(dict(key=name, copy=spec.copyfile)) return info + +def check_requires(inputs, requires): + """check if required inputs are satisfied + """ + if not requires: + return True + + # Check value and all required inputs' values defined + values = [isdefined(getattr(inputs, field)) + for field in requires] + return all(values) + +def check_xor(inputs, name, xor): + """ check if mutually exclusive inputs are satisfied + """ + if len(xor) == 0: + return True + + values = [isdefined(getattr(inputs, name))] + values += [any([isdefined(getattr(inputs, field)) + for field in xor])] + return sum(values) + +def check_mandatory_inputs(inputs, raise_exc=True): + """ Raises an exception if a mandatory input is Undefined + """ + # Check mandatory, not xor-ed inputs. + for name, spec in list(inputs.traits(mandatory=True).items()): + value = getattr(inputs, name) + # Mandatory field is defined, check xor'ed inputs + xor = spec.xor or [] + has_xor = bool(xor) + has_value = isdefined(value) + + # Simplest case: no xor metadata and not defined + if not has_xor and not has_value: + if raise_exc: + raise MandatoryInputError(inputs, name) + return False + + xor = set(list(xor) if isinstance(xor, (list, tuple)) + else [xor]) + xor.discard(name) + xor = list(xor) + cxor = check_xor(inputs, name, xor) + if cxor != 1: + if raise_exc: + raise MutuallyExclusiveInputError( + inputs, name, values_defined=cxor) + return False + + # Check whether mandatory inputs require others + if has_value and not check_requires(inputs, spec.requires): + if raise_exc: + raise RequiredInputError(inputs, name) + return False + + # Check requirements of non-mandatory inputs + for name, spec in list( + inputs.traits(mandatory=None, transient=None).items()): + value = getattr(inputs, name) # value must be set to follow requires + if isdefined(value) and not check_requires(inputs, spec.requires): + if raise_exc: + raise RequiredInputError(inputs, name) + + return True + +def check_version(traited_spec, version=None, raise_exc=True): + """ Raises an exception on version mismatch + """ + + # no version passed on to check against + if not version: + return [] + + # check minimum version + names = traited_spec.trait_names( + min_ver=lambda t: t is not None) + \ + traited_spec.trait_names( + max_ver=lambda t: t is not None) + + # no traits defined any versions + if not names: + return [] + + version = Version(str(version)) + unavailable_traits = [] + for name in names: + value_set = isdefined(getattr(traited_spec, name)) + min_ver = traited_spec.traits()[name].min_ver + if min_ver: + min_ver = Version(str(min_ver)) + + max_ver = traited_spec.traits()[name].max_ver + if max_ver: + max_ver = Version(str(max_ver)) + + if min_ver and max_ver: + if max_ver < min_ver: + raise AssertionError( + 'Trait "%s" (%s) has incongruent version metadata ' + '(``max_ver`` is lower than ``min_ver``).' % ( + traited_spec.__class__.__name__, name)) + + if min_ver and (min_ver > version): + unavailable_traits.append(name) + if value_set and raise_exc: + raise VersionIOError(traited_spec, name, version) + if max_ver and (max_ver < version): + unavailable_traits.append(name) + if value_set and raise_exc: + raise VersionIOError(traited_spec, name, version) + + return list(set(unavailable_traits)) diff --git a/nipype/interfaces/base/tests/test_core.py b/nipype/interfaces/base/tests/test_core.py index bcbd43db28..c621446bfc 100644 --- a/nipype/interfaces/base/tests/test_core.py +++ b/nipype/interfaces/base/tests/test_core.py @@ -88,11 +88,6 @@ class DerivedInterface(nib.BaseInterface): assert 'moo' in ''.join(DerivedInterface._inputs_help()) assert DerivedInterface()._outputs() is None assert DerivedInterface().inputs.foo == nib.Undefined - with pytest.raises(ValueError): - DerivedInterface()._check_mandatory_inputs() - assert DerivedInterface(goo=1)._check_mandatory_inputs() is None - with pytest.raises(ValueError): - DerivedInterface().run() with pytest.raises(NotImplementedError): DerivedInterface(goo=1).run() @@ -170,136 +165,15 @@ def __init__(self, **inputs): assert '8562a5623562a871115eb14822ee8d02' == hashvalue -class MinVerInputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int', min_ver='0.9') - -class MaxVerInputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int', max_ver='0.7') - - -def test_input_version_1(): - class DerivedInterface1(nib.BaseInterface): - input_spec = MinVerInputSpec - - obj = DerivedInterface1() - obj._check_version_requirements(obj.inputs) - +def test_stop_on_unknown_version(): config.set('execution', 'stop_on_unknown_version', True) + ci = nib.CommandLine(command='which') with pytest.raises(ValueError) as excinfo: - obj._check_version_requirements(obj.inputs) + _ = ci.version assert "no version information" in str(excinfo.value) - config.set_default_config() - -def test_input_version_2(): - class DerivedInterface1(nib.BaseInterface): - input_spec = MinVerInputSpec - _version = '0.8' - - obj = DerivedInterface1() - obj.inputs.foo = 1 - with pytest.raises(Exception) as excinfo: - obj._check_version_requirements(obj.inputs) - assert "version 0.8 < required 0.9" in str(excinfo.value) - - -def test_input_version_3(): - class DerivedInterface1(nib.BaseInterface): - input_spec = MinVerInputSpec - _version = '0.10' - - obj = DerivedInterface1() - obj._check_version_requirements(obj.inputs) - - -def test_input_version_4(): - class DerivedInterface1(nib.BaseInterface): - input_spec = MinVerInputSpec - _version = '0.9' - - obj = DerivedInterface1() - obj.inputs.foo = 1 - obj._check_version_requirements(obj.inputs) - - -def test_input_version_5(): - class DerivedInterface2(nib.BaseInterface): - input_spec = MaxVerInputSpec - _version = '0.8' - - obj = DerivedInterface2() - obj.inputs.foo = 1 - with pytest.raises(Exception) as excinfo: - obj._check_version_requirements(obj.inputs) - assert "version 0.8 > required 0.7" in str(excinfo.value) - - -def test_input_version_6(): - class DerivedInterface1(nib.BaseInterface): - input_spec = MaxVerInputSpec - _version = '0.7' - - obj = DerivedInterface1() - obj.inputs.foo = 1 - obj._check_version_requirements(obj.inputs) - - -def test_output_version(): - class InputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int') - - class OutputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int', min_ver='0.9') - - class DerivedInterface1(nib.BaseInterface): - input_spec = InputSpec - output_spec = OutputSpec - _version = '0.10' - resource_monitor = False - - obj = DerivedInterface1() - assert obj._check_version_requirements(obj._outputs()) == [] - - class InputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int') - - class OutputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int', min_ver='0.11') - - class DerivedInterface1(nib.BaseInterface): - input_spec = InputSpec - output_spec = OutputSpec - _version = '0.10' - resource_monitor = False - - obj = DerivedInterface1() - assert obj._check_version_requirements(obj._outputs()) == ['foo'] - - class InputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int') - - class OutputSpec(nib.TraitedSpec): - foo = nib.traits.Int(desc='a random int', min_ver='0.11') - - class DerivedInterface1(nib.BaseInterface): - input_spec = InputSpec - output_spec = OutputSpec - _version = '0.10' - resource_monitor = False - - def _run_interface(self, runtime): - return runtime - - def _list_outputs(self): - return {'foo': 1} - - obj = DerivedInterface1() - with pytest.raises(KeyError): - obj.run() - - def test_Commandline(): with pytest.raises(Exception): nib.CommandLine() diff --git a/nipype/interfaces/base/tests/test_specs.py b/nipype/interfaces/base/tests/test_specs.py index bab112e96d..1ded42ee84 100644 --- a/nipype/interfaces/base/tests/test_specs.py +++ b/nipype/interfaces/base/tests/test_specs.py @@ -2,15 +2,19 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: from __future__ import print_function, unicode_literals -from future import standard_library import os import warnings - import pytest +from future import standard_library from ....utils.filemanip import split_filename from ... import base as nib from ...base import traits, Undefined +from ..specs import ( + check_mandatory_inputs, check_version, + MandatoryInputError, MutuallyExclusiveInputError, + RequiredInputError, VersionIOError +) from ....interfaces import fsl from ...utility.wrappers import Function from ....pipeline import Node @@ -134,7 +138,7 @@ class MyInterface(nib.BaseInterface): myif.inputs.foo = 1 assert myif.inputs.foo == 1 set_bar = lambda: setattr(myif.inputs, 'bar', 1) - with pytest.raises(IOError): + with pytest.raises(MutuallyExclusiveInputError): set_bar() assert myif.inputs.foo == 1 myif.inputs.kung = 2 @@ -447,7 +451,6 @@ class InputSpec(nib.TraitedSpec): class DerivedInterface(nib.BaseInterface): input_spec = InputSpec resource_monitor = False - def normalize_filenames(self): """A mock normalize_filenames for freesurfer interfaces that have one""" self.inputs.zoo = 'normalized_filename.ext' @@ -475,3 +478,148 @@ def normalize_filenames(self): assert info[0]['copy'] assert info[1]['key'] == 'zoo' assert not info[1]['copy'] + + +def test_inputs_checks(): + + class InputSpec(nib.TraitedSpec): + goo = nib.traits.Int(desc='a random int', mandatory=True) + + class DerivedInterface(nib.BaseInterface): + input_spec = InputSpec + resource_monitor = False + + assert check_mandatory_inputs( + DerivedInterface(goo=1).inputs) + with pytest.raises(MandatoryInputError): + check_mandatory_inputs( + DerivedInterface().inputs) + with pytest.raises(MandatoryInputError): + DerivedInterface().run() + + class InputSpec(nib.TraitedSpec): + goo = nib.traits.Int(desc='a random int', mandatory=True, + requires=['woo']) + woo = nib.traits.Int(desc='required by goo') + + class DerivedInterface(nib.BaseInterface): + input_spec = InputSpec + resource_monitor = False + + assert check_mandatory_inputs( + DerivedInterface(goo=1, woo=1).inputs) + with pytest.raises(RequiredInputError): + check_mandatory_inputs( + DerivedInterface(goo=1).inputs) + with pytest.raises(RequiredInputError): + DerivedInterface(goo=1).run() + + class InputSpec(nib.TraitedSpec): + goo = nib.traits.Int(desc='a random int', mandatory=True, + xor=['woo']) + woo = nib.traits.Int(desc='a random int', mandatory=True, + xor=['goo']) + + class DerivedInterface(nib.BaseInterface): + input_spec = InputSpec + resource_monitor = False + + # If either goo or woo are set, then okay! + assert check_mandatory_inputs( + DerivedInterface(goo=1).inputs) + assert check_mandatory_inputs( + DerivedInterface(woo=1).inputs) + + # None are set, raise MandatoryInputError + with pytest.raises(MutuallyExclusiveInputError): + check_mandatory_inputs( + DerivedInterface().inputs) + + # Both are set, raise MutuallyExclusiveInputError + with pytest.raises(MutuallyExclusiveInputError): + check_mandatory_inputs( + DerivedInterface(goo=1, woo=1).inputs) + with pytest.raises(MutuallyExclusiveInputError): + DerivedInterface(goo=1, woo=1).run() + + +def test_input_version(): + class MinVerInputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', min_ver='0.5') + + + assert check_version(MinVerInputSpec(), '0.6') == [] + assert check_version(MinVerInputSpec(), '0.4') == ['foo'] + with pytest.raises(VersionIOError): + check_version(MinVerInputSpec(foo=1), '0.4') + + + class MaxVerInputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', max_ver='0.7') + + + assert check_version(MaxVerInputSpec(), '0.6') == [] + assert check_version(MaxVerInputSpec(), '0.8') == ['foo'] + with pytest.raises(VersionIOError): + check_version(MaxVerInputSpec(foo=1), '0.8') + + + class MinMaxVerInputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', max_ver='0.7', + min_ver='0.5') + + + assert check_version(MinMaxVerInputSpec(), '0.6') == [] + assert check_version(MinMaxVerInputSpec(), '0.4') == ['foo'] + assert check_version(MinMaxVerInputSpec(), '0.8') == ['foo'] + with pytest.raises(VersionIOError): + check_version(MinMaxVerInputSpec(foo=1), '0.8') + with pytest.raises(VersionIOError): + check_version(MinMaxVerInputSpec(foo=1), '0.4') + + + class FixedVerInputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', max_ver='0.6.2', + min_ver='0.6.2') + + + assert check_version(FixedVerInputSpec(), '0.6.2') == [] + assert check_version(FixedVerInputSpec(), '0.6.1') == ['foo'] + assert check_version(FixedVerInputSpec(), '0.6.3') == ['foo'] + with pytest.raises(VersionIOError): + check_version(FixedVerInputSpec(foo=1), '0.6.1') + with pytest.raises(VersionIOError): + check_version(FixedVerInputSpec(foo=1), '0.6.3') + + + class IncongruentVerInputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', max_ver='0.5', + min_ver='0.7') + + with pytest.raises(AssertionError): + check_version(IncongruentVerInputSpec(), '0.6') + with pytest.raises(AssertionError): + check_version(IncongruentVerInputSpec(foo=1), '0.6') + + + class InputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int') + + class OutputSpec(nib.TraitedSpec): + foo = nib.traits.Int(desc='a random int', min_ver='0.11') + + class DerivedInterface1(nib.BaseInterface): + input_spec = InputSpec + output_spec = OutputSpec + _version = '0.10' + resource_monitor = False + + def _run_interface(self, runtime): + return runtime + + def _list_outputs(self): + return {'foo': 1} + + obj = DerivedInterface1() + with pytest.raises(KeyError): + obj.run() diff --git a/nipype/interfaces/freesurfer/preprocess.py b/nipype/interfaces/freesurfer/preprocess.py index 2941968f85..b25c7fc6f6 100644 --- a/nipype/interfaces/freesurfer/preprocess.py +++ b/nipype/interfaces/freesurfer/preprocess.py @@ -21,6 +21,7 @@ from ..base import (TraitedSpec, File, traits, Directory, InputMultiPath, OutputMultiPath, CommandLine, CommandLineInputSpec, isdefined) +from ..base.specs import check_mandatory_inputs from .base import (FSCommand, FSTraitedSpec, FSTraitedSpecOpenMP, FSCommandOpenMP, Info) from .utils import copy2subjdir @@ -634,7 +635,12 @@ def _get_filelist(self, outdir): def cmdline(self): """ `command` plus any arguments (args) validates arguments and generates command line""" - self._check_mandatory_inputs() + if not check_mandatory_inputs(self.inputs, raise_exc=False): + iflogger.warning( + 'Some inputs are not valid. Please make sure all mandatory ' + 'inputs, required inputs and mutually-exclusive inputs are ' + 'set or in a sane state.') + outdir = self._get_outdir() cmd = [] if not os.path.exists(outdir): diff --git a/nipype/interfaces/freesurfer/tests/test_preprocess.py b/nipype/interfaces/freesurfer/tests/test_preprocess.py index f9fc09515a..64dd3b8379 100644 --- a/nipype/interfaces/freesurfer/tests/test_preprocess.py +++ b/nipype/interfaces/freesurfer/tests/test_preprocess.py @@ -7,24 +7,26 @@ import pytest from nipype.testing.fixtures import create_files_in_directory -from nipype.interfaces import freesurfer +from nipype.interfaces import freesurfer as fs from nipype.interfaces.freesurfer import Info from nipype import LooseVersion +from nipype.utils import errors as nue @pytest.mark.skipif( - freesurfer.no_freesurfer(), reason="freesurfer is not installed") + fs.no_freesurfer(), reason="freesurfer is not installed") def test_robustregister(create_files_in_directory): filelist, outdir = create_files_in_directory - reg = freesurfer.RobustRegister() + reg = fs.RobustRegister() cwd = os.getcwd() # make sure command gets called assert reg.cmd == 'mri_robust_register' + assert reg.cmdline is None # test raising error with mandatory args absent - with pytest.raises(ValueError): + with pytest.raises(nue.MandatoryInputError): reg.run() # .inputs based parameters setting @@ -36,7 +38,7 @@ def test_robustregister(create_files_in_directory): (cwd, filelist[0][:-4], filelist[0], filelist[1])) # constructor based parameter setting - reg2 = freesurfer.RobustRegister( + reg2 = fs.RobustRegister( source_file=filelist[0], target_file=filelist[1], outlier_sens=3.0, @@ -49,17 +51,18 @@ def test_robustregister(create_files_in_directory): @pytest.mark.skipif( - freesurfer.no_freesurfer(), reason="freesurfer is not installed") + fs.no_freesurfer(), reason="freesurfer is not installed") def test_fitmsparams(create_files_in_directory): filelist, outdir = create_files_in_directory - fit = freesurfer.FitMSParams() + fit = fs.FitMSParams() # make sure command gets called assert fit.cmd == 'mri_ms_fitparms' + assert fit.cmdline is None # test raising error with mandatory args absent - with pytest.raises(ValueError): + with pytest.raises(nue.MandatoryInputError): fit.run() # .inputs based parameters setting @@ -69,7 +72,7 @@ def test_fitmsparams(create_files_in_directory): filelist[1], outdir) # constructor based parameter setting - fit2 = freesurfer.FitMSParams( + fit2 = fs.FitMSParams( in_files=filelist, te_list=[1.5, 3.5], flip_list=[20, 30], @@ -80,17 +83,18 @@ def test_fitmsparams(create_files_in_directory): @pytest.mark.skipif( - freesurfer.no_freesurfer(), reason="freesurfer is not installed") + fs.no_freesurfer(), reason="freesurfer is not installed") def test_synthesizeflash(create_files_in_directory): filelist, outdir = create_files_in_directory - syn = freesurfer.SynthesizeFLASH() + syn = fs.SynthesizeFLASH() # make sure command gets called assert syn.cmd == 'mri_synthesize' + assert syn.cmdline is None # test raising error with mandatory args absent - with pytest.raises(ValueError): + with pytest.raises(nue.MandatoryInputError): syn.run() # .inputs based parameters setting @@ -105,7 +109,7 @@ def test_synthesizeflash(create_files_in_directory): os.path.join(outdir, 'synth-flash_30.mgz'))) # constructor based parameters setting - syn2 = freesurfer.SynthesizeFLASH( + syn2 = fs.SynthesizeFLASH( t1_image=filelist[0], pd_image=filelist[1], flip_angle=20, te=5, tr=25) assert syn2.cmdline == ('mri_synthesize 25.00 20.00 5.000 %s %s %s' % (filelist[0], filelist[1], @@ -113,17 +117,18 @@ def test_synthesizeflash(create_files_in_directory): @pytest.mark.skipif( - freesurfer.no_freesurfer(), reason="freesurfer is not installed") + fs.no_freesurfer(), reason="freesurfer is not installed") def test_mandatory_outvol(create_files_in_directory): filelist, outdir = create_files_in_directory - mni = freesurfer.MNIBiasCorrection() + mni = fs.MNIBiasCorrection() # make sure command gets called assert mni.cmd == "mri_nu_correct.mni" + assert mni.cmdline is None # test raising error with mandatory args absent - with pytest.raises(ValueError): - mni.cmdline + with pytest.raises(nue.MandatoryInputError): + mni.run() # test with minimal args mni.inputs.in_file = filelist[0] @@ -141,7 +146,7 @@ def test_mandatory_outvol(create_files_in_directory): 'mri_nu_correct.mni --i %s --n 4 --o new_corrected_file.mgz' % (filelist[0])) # constructor based tests - mni2 = freesurfer.MNIBiasCorrection( + mni2 = fs.MNIBiasCorrection( in_file=filelist[0], out_file='bias_corrected_output', iterations=2) assert mni2.cmdline == ( 'mri_nu_correct.mni --i %s --n 2 --o bias_corrected_output' % @@ -149,17 +154,23 @@ def test_mandatory_outvol(create_files_in_directory): @pytest.mark.skipif( - freesurfer.no_freesurfer(), reason="freesurfer is not installed") -def test_bbregister(create_files_in_directory): + fs.no_freesurfer(), reason="freesurfer is not installed") +def test_bbregister(caplog, create_files_in_directory): filelist, outdir = create_files_in_directory - bbr = freesurfer.BBRegister() + bbr = fs.BBRegister() # make sure command gets called assert bbr.cmd == "bbregister" + # cmdline issues a warning: mandatory inputs missing + assert bbr.cmdline is None + + captured = caplog.text + assert 'Command line could not be generated' in captured + # test raising error with mandatory args absent - with pytest.raises(ValueError): - bbr.cmdline + with pytest.raises(nue.MandatoryInputError): + bbr.run() bbr.inputs.subject_id = 'fsaverage' bbr.inputs.source_file = filelist[0] @@ -167,10 +178,14 @@ def test_bbregister(create_files_in_directory): # Check that 'init' is mandatory in FS < 6, but not in 6+ if Info.looseversion() < LooseVersion("6.0.0"): - with pytest.raises(ValueError): - bbr.cmdline + assert bbr.cmdline is None + captured = caplog.text + assert 'Command line could not be generated' in captured + + with pytest.raises(nue.VersionIOError): + bbr.run() else: - bbr.cmdline + assert bbr.cmdline is not None bbr.inputs.init = 'fsl' @@ -187,5 +202,5 @@ def test_bbregister(create_files_in_directory): def test_FSVersion(): """Check that FSVersion is a string that can be compared with LooseVersion """ - assert isinstance(freesurfer.preprocess.FSVersion, str) - assert LooseVersion(freesurfer.preprocess.FSVersion) >= LooseVersion("0") + assert isinstance(fs.preprocess.FSVersion, str) + assert LooseVersion(fs.preprocess.FSVersion) >= LooseVersion("0") diff --git a/nipype/interfaces/fsl/tests/test_preprocess.py b/nipype/interfaces/fsl/tests/test_preprocess.py index 4b387201cf..578fb745a2 100644 --- a/nipype/interfaces/fsl/tests/test_preprocess.py +++ b/nipype/interfaces/fsl/tests/test_preprocess.py @@ -15,6 +15,7 @@ from nipype.interfaces.fsl import Info from nipype.interfaces.base import File, TraitError, Undefined, isdefined from nipype.interfaces.fsl import no_fsl +from nipype.utils import errors as nue def fsl_name(obj, fname): @@ -39,7 +40,7 @@ def test_bet(setup_infile): assert better.cmd == 'bet' # Test raising error with mandatory args absent - with pytest.raises(ValueError): + with pytest.raises(nue.MandatoryInputError): better.run() # Test generated outfile name @@ -195,7 +196,7 @@ def setup_flirt(tmpdir): @pytest.mark.skipif(no_fsl(), reason="fsl is not installed") -def test_flirt(setup_flirt): +def test_flirt(caplog, setup_flirt): # setup tmpdir, infile, reffile = setup_flirt @@ -230,12 +231,24 @@ def test_flirt(setup_flirt): flirter = fsl.FLIRT() # infile not specified - with pytest.raises(ValueError): - flirter.cmdline + assert flirter.cmdline is None + captured = caplog.text + assert 'Command line could not be generated' in captured + + # interface should raise error with mandatory inputs unset + with pytest.raises(nue.MandatoryInputError): + flirter.run() + flirter.inputs.in_file = infile # reference not specified - with pytest.raises(ValueError): - flirter.cmdline + assert flirter.cmdline is None + captured = caplog.text + assert 'Command line could not be generated' in captured + + # interface should raise error with reference still unset + with pytest.raises(nue.MandatoryInputError): + flirter.run() + flirter.inputs.reference = reffile # Generate outfile and outmatrix @@ -380,10 +393,10 @@ def test_mcflirt_opt(setup_flirt): def test_mcflirt_noinput(): # Test error is raised when missing required args fnt = fsl.MCFLIRT() - with pytest.raises(ValueError) as excinfo: + with pytest.raises(nue.MandatoryInputError) as excinfo: fnt.run() assert str(excinfo.value).startswith( - "MCFLIRT requires a value for input 'in_file'") + 'Interface "MCFLIRT" requires a value for input in_file.') # test fnirt @@ -441,9 +454,9 @@ def test_fnirt(setup_flirt): iout) assert fnirt.cmdline == cmd - # Test ValueError is raised when missing mandatory args + # Test nue.MandatoryInputError is raised when missing mandatory args fnirt = fsl.FNIRT() - with pytest.raises(ValueError): + with pytest.raises(nue.MandatoryInputError): fnirt.run() fnirt.inputs.in_file = infile fnirt.inputs.ref_file = reffile diff --git a/nipype/interfaces/fsl/utils.py b/nipype/interfaces/fsl/utils.py index f4ef73c0e9..b85101f867 100644 --- a/nipype/interfaces/fsl/utils.py +++ b/nipype/interfaces/fsl/utils.py @@ -238,16 +238,29 @@ class Smooth(FSLCommand): >>> sm.cmdline # doctest: +ELLIPSIS 'fslmaths functional2.nii -kernel gauss 3.397 -fmean functional2_smooth.nii.gz' - One of sigma or fwhm must be set: + One of sigma or fwhm must be set. Accessing the ``cmdline`` property + will return ``None`` and issue a warning: >>> from nipype.interfaces.fsl import Smooth >>> sm = Smooth() >>> sm.inputs.output_type = 'NIFTI_GZ' >>> sm.inputs.in_file = 'functional2.nii' - >>> sm.cmdline #doctest: +ELLIPSIS + >>> sm.cmdline is None # doctest: +ELLIPSIS + True + + The warning is: :: + 181125-08:12:09,489 nipype.interface WARNING: + Some inputs are not valid. Please make sure all mandatory + inputs, required inputs and mutually-exclusive inputs are + set or in a sane state. + + Attempting to run the interface without the necessary inputs will + lead to an error (in this case, a ``MutuallyExclusiveInputError``): + + >>> sm.run() # doctest: +ELLIPSIS Traceback (most recent call last): - ... - ValueError: Smooth requires a value for one of the inputs ... + ... + MutuallyExclusiveInputError: Interface ... """ diff --git a/nipype/pytest.ini b/nipype/pytest.ini index 70f12b64aa..5f22555598 100644 --- a/nipype/pytest.ini +++ b/nipype/pytest.ini @@ -1,6 +1,6 @@ [pytest] norecursedirs = .git build dist doc nipype/external tools examples src addopts = --doctest-modules -n auto -doctest_optionflags = ALLOW_UNICODE NORMALIZE_WHITESPACE +doctest_optionflags = ALLOW_UNICODE NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL env = PYTHONHASHSEED=0 diff --git a/nipype/utils/errors.py b/nipype/utils/errors.py new file mode 100644 index 0000000000..1fff7404c6 --- /dev/null +++ b/nipype/utils/errors.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +"""Errors and exceptions +""" +from __future__ import (print_function, division, unicode_literals, + absolute_import) + + +class MandatoryInputError(ValueError): + """Raised when one input with the ``mandatory`` metadata set to ``True`` is + not defined.""" + def __init__(self, inputspec, name): + classname = _classname_from_spec(inputspec) + msg = ( + 'Interface "{classname}" requires a value for input {name}. ' + 'For a list of required inputs, see {classname}.help().').format( + classname=classname, name=name) + super(MandatoryInputError, self).__init__(msg) + +class MutuallyExclusiveInputError(ValueError): + """Raised when none or more than one mutually-exclusive inputs are set.""" + def __init__(self, inputspec, name, values_defined=None, name_other=None): + classname = _classname_from_spec(inputspec) + + if values_defined is not None: + xor = inputspec.traits()[name].xor or [] + xor = set(list(xor) if isinstance(xor, (list, tuple)) + else [xor]) + xor.add(name) + msg = ('Interface "{classname}" has mutually-exclusive inputs ' + '(processing "{name}", with value={value}). ' + 'Exactly one of ({xor}) should be set, but {n:d} were set. ' + 'For a list of mutually-exclusive inputs, see ' + '{classname}.help().').format(classname=classname, + xor='|'.join(xor), + n=values_defined, + name=name, + value=getattr(inputspec, name)) + + else: + msg = ('Interface "{classname}" has mutually-exclusive inputs. ' + 'Input "{name}" is mutually exclusive with input ' + '"{name_other}", which is already set').format( + classname=classname, name=name, name_other=name_other) + super(MutuallyExclusiveInputError, self).__init__(msg) + +class RequiredInputError(ValueError): + """Raised when one input requires some other and those or some of + those are ``Undefined``.""" + def __init__(self, inputspec, name): + classname = _classname_from_spec(inputspec) + requires = inputspec.traits()[name].requires + + msg = ('Interface "{classname}" requires a value for input {name} ' + 'because one of ({requires}) is set. For a list of required ' + 'inputs, see {classname}.help().').format( + classname=classname, name=name, + requires=', '.join(requires)) + super(RequiredInputError, self).__init__(msg) + +class VersionIOError(ValueError): + """Raised when one input with the ``mandatory`` metadata set to ``True`` is + not defined.""" + def __init__(self, spec, name, version): + classname = _classname_from_spec(spec) + max_ver = spec.traits()[name].max_ver + min_ver = spec.traits()[name].min_ver + + msg = ('Interface "{classname}" has version requirements for ' + '{name}, but version {version} was found. ').format( + classname=classname, name=name, version=version) + + if min_ver: + msg += 'Minimum version is %s. ' % min_ver + if max_ver: + msg += 'Maximum version is %s. ' % max_ver + + super(VersionIOError, self).__init__(msg) + +def _classname_from_spec(spec): + classname = spec.__class__.__name__ + + kind = 'Output' if 'Output' in classname else 'Input' + # General pattern is that spec ends in KindSpec + if classname.endswith(kind + 'Spec') and classname != (kind + 'Spec'): + classname = classname[:-len(kind + 'Spec')] + + # Catch some special cases such as ANTS + if classname.endswith(kind) and classname != kind: + classname = classname[:-len(kind)] + + return classname