Skip to content

[MAINT] Outsource checks of inputs from interface #2799

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7cb2038
[MAINT] Outsource checks of inputs from interface
oesteban Nov 25, 2018
a15c6fb
remove forgotten print
oesteban Nov 25, 2018
a3600d9
fix mutually-exclusive check
oesteban Nov 25, 2018
c095d65
fix bug trying to append a spec.xor that is None
oesteban Nov 25, 2018
61d553c
fix precedence
oesteban Nov 25, 2018
1ade70c
outsource ``_check_version_requirements``
oesteban Nov 25, 2018
94c0823
fix multiple xor check
oesteban Nov 25, 2018
fddec69
fix check_requires for non-mandatory requires
oesteban Nov 25, 2018
34cb9c9
fix massaging xored inputs
oesteban Nov 25, 2018
523475e
Merge remote-tracking branch 'upstream/master' into maint/outsource-c…
oesteban Nov 25, 2018
41a48bd
fix test_extra_Registration
oesteban Nov 25, 2018
8f8ceb1
fixed fsl.utils doctests
oesteban Nov 25, 2018
d0c6ab8
cmdline returns ``None`` instead of raising if error on inputs
oesteban Nov 25, 2018
1b948cf
fix one more use-case
oesteban Nov 25, 2018
1a558a4
fixed tests
oesteban Nov 25, 2018
b08b5ad
fix errors checking xor
oesteban Nov 25, 2018
44a724d
fix fs.preprocess test
oesteban Nov 25, 2018
f7b5650
fix fsl.tests.test_preprocess
oesteban Nov 26, 2018
294850b
fix fsl.tests.test_preprocess
oesteban Nov 26, 2018
13ac0b9
Merge branch 'maint/outsource-checks' of github.com:oesteban/nipype i…
oesteban Nov 26, 2018
89da56a
install caplog fixture
oesteban Nov 26, 2018
fff59eb
pin pytest>=3.4 for reliable caplog fixture
oesteban Nov 26, 2018
ce41579
Merge branch 'master' into maint/outsource-checks
oesteban Nov 27, 2018
7967b08
Merge remote-tracking branch 'upstream/master' into maint/outsource-c…
oesteban Nov 28, 2018
2d5fd05
Merge branch 'maint/outsource-checks' of github.com:oesteban/nipype i…
oesteban Nov 28, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions nipype/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_nipype_gitversion():
SCIPY_MIN_VERSION = '0.14'
TRAITS_MIN_VERSION = '4.6'
DATEUTIL_MIN_VERSION = '2.2'
PYTEST_MIN_VERSION = '3.0'
PYTEST_MIN_VERSION = '3.4'
FUTURE_MIN_VERSION = '0.16.0'
SIMPLEJSON_MIN_VERSION = '3.8.0'
PROV_VERSION = '1.5.2'
Expand Down Expand Up @@ -159,7 +159,12 @@ def get_nipype_gitversion():
if sys.version_info <= (3, 4):
REQUIRES.append('configparser')

TESTS_REQUIRES = ['pytest-cov', 'codecov', 'pytest-env', 'coverage<5']
TESTS_REQUIRES = [
'pytest-cov',
'pytest-env',
'codecov',
'coverage<5',
]

EXTRA_REQUIRES = {
'doc': ['Sphinx>=1.4', 'numpydoc', 'matplotlib', 'pydotplus', 'pydot>=1.2.3'],
Expand Down
7 changes: 4 additions & 3 deletions nipype/interfaces/ants/tests/test_extra_Registration.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
from __future__ import unicode_literals
from nipype.interfaces.ants import registration
import os
import pytest
from nipype.interfaces.ants import registration
from nipype.utils.errors import MandatoryInputError


def test_ants_mand(tmpdir):
Expand All @@ -17,6 +18,6 @@ def test_ants_mand(tmpdir):
ants.inputs.fixed_image = [os.path.join(datadir, 'T1.nii')]
ants.inputs.metric = ['MI']

with pytest.raises(ValueError) as er:
with pytest.raises(MandatoryInputError) as er:
ants.run()
assert "ANTS requires a value for input 'radius'" in str(er.value)
assert 'Interface "ANTS" requires a value for input radius.' in str(er.value)
108 changes: 14 additions & 94 deletions nipype/interfaces/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import simplejson as json
from dateutil.parser import parse as parseutc

from ... import config, logging, LooseVersion
from ... import config, logging
from ...utils.provenance import write_provenance
from ...utils.misc import trim, str2bool, rgetcwd
from ...utils.filemanip import (FileNotFoundError, split_filename,
Expand All @@ -41,7 +41,7 @@
from .traits_extension import traits, isdefined, TraitError
from .specs import (BaseInterfaceInputSpec, CommandLineInputSpec,
StdOutCommandLineInputSpec, MpiCommandLineInputSpec,
get_filecopy_info)
get_filecopy_info, check_mandatory_inputs, check_version)
from .support import (Bunch, InterfaceResult, NipypeInterfaceError)

from future import standard_library
Expand Down Expand Up @@ -335,93 +335,6 @@ def _outputs(self):

return outputs

def _check_requires(self, spec, name, value):
""" check if required inputs are satisfied
"""
if spec.requires:
values = [
not isdefined(getattr(self.inputs, field))
for field in spec.requires
]
if any(values) and isdefined(value):
msg = ("%s requires a value for input '%s' because one of %s "
"is set. For a list of required inputs, see %s.help()" %
(self.__class__.__name__, name,
', '.join(spec.requires), self.__class__.__name__))
raise ValueError(msg)

def _check_xor(self, spec, name, value):
""" check if mutually exclusive inputs are satisfied
"""
if spec.xor:
values = [
isdefined(getattr(self.inputs, field)) for field in spec.xor
]
if not any(values) and not isdefined(value):
msg = ("%s requires a value for one of the inputs '%s'. "
"For a list of required inputs, see %s.help()" %
(self.__class__.__name__, ', '.join(spec.xor),
self.__class__.__name__))
raise ValueError(msg)

def _check_mandatory_inputs(self):
""" Raises an exception if a mandatory input is Undefined
"""
for name, spec in list(self.inputs.traits(mandatory=True).items()):
value = getattr(self.inputs, name)
self._check_xor(spec, name, value)
if not isdefined(value) and spec.xor is None:
msg = ("%s requires a value for input '%s'. "
"For a list of required inputs, see %s.help()" %
(self.__class__.__name__, name,
self.__class__.__name__))
raise ValueError(msg)
if isdefined(value):
self._check_requires(spec, name, value)
for name, spec in list(
self.inputs.traits(mandatory=None, transient=None).items()):
self._check_requires(spec, name, getattr(self.inputs, name))

def _check_version_requirements(self, trait_object, raise_exception=True):
""" Raises an exception on version mismatch
"""
unavailable_traits = []
# check minimum version
check = dict(min_ver=lambda t: t is not None)
names = trait_object.trait_names(**check)

if names and self.version:
version = LooseVersion(str(self.version))
for name in names:
min_ver = LooseVersion(
str(trait_object.traits()[name].min_ver))
if min_ver > version:
unavailable_traits.append(name)
if not isdefined(getattr(trait_object, name)):
continue
if raise_exception:
raise Exception(
'Trait %s (%s) (version %s < required %s)' %
(name, self.__class__.__name__, version, min_ver))

# check maximum version
check = dict(max_ver=lambda t: t is not None)
names = trait_object.trait_names(**check)
if names and self.version:
version = LooseVersion(str(self.version))
for name in names:
max_ver = LooseVersion(
str(trait_object.traits()[name].max_ver))
if max_ver < version:
unavailable_traits.append(name)
if not isdefined(getattr(trait_object, name)):
continue
if raise_exception:
raise Exception(
'Trait %s (%s) (version %s > required %s)' %
(name, self.__class__.__name__, version, max_ver))
return unavailable_traits

def _run_interface(self, runtime):
""" Core function that executes interface
"""
Expand Down Expand Up @@ -466,8 +379,8 @@ def run(self, cwd=None, ignore_exception=None, **inputs):

enable_rm = config.resource_monitor and self.resource_monitor
self.inputs.trait_set(**inputs)
self._check_mandatory_inputs()
self._check_version_requirements(self.inputs)
check_mandatory_inputs(self.inputs)
check_version(self.inputs, version=self.version)
interface = self.__class__
self._duecredit_cite()

Expand Down Expand Up @@ -593,8 +506,8 @@ def aggregate_outputs(self, runtime=None, needed_outputs=None):
if predicted_outputs:
_unavailable_outputs = []
if outputs:
_unavailable_outputs = \
self._check_version_requirements(self._outputs())
_unavailable_outputs = check_version(
self._outputs(), self.version)
for key, val in list(predicted_outputs.items()):
if needed_outputs and key not in needed_outputs:
continue
Expand Down Expand Up @@ -810,7 +723,14 @@ def cmd(self):
def cmdline(self):
""" `command` plus any arguments (args)
validates arguments and generates command line"""
self._check_mandatory_inputs()
if not check_mandatory_inputs(self.inputs, raise_exc=False):
iflogger.warning(
'Command line could not be generated because some inputs '
'are not valid. Please make sure all mandatory inputs, '
'required inputs and mutually-exclusive inputs are set '
'or in a sane state.')
return None

allargs = [self._cmd_prefix + self.cmd] + self._parse_inputs()
return ' '.join(allargs)

Expand Down
122 changes: 119 additions & 3 deletions nipype/interfaces/base/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from packaging.version import Version

from ...utils.filemanip import md5, hash_infile, hash_timestamp, to_str
from ...utils.errors import (
MandatoryInputError, MutuallyExclusiveInputError, RequiredInputError,
VersionIOError)
from .traits_extension import (
traits,
Undefined,
Expand Down Expand Up @@ -115,9 +118,8 @@ def _xor_warn(self, obj, name, old, new):
trait_change_notify=False, **{
'%s' % name: Undefined
})
msg = ('Input "%s" is mutually exclusive with input "%s", '
'which is already set') % (name, trait_name)
raise IOError(msg)
raise MutuallyExclusiveInputError(
self, name, name_other=trait_name)

def _deprecated_warn(self, obj, name, old, new):
"""Checks if a user assigns a value to a deprecated trait
Expand Down Expand Up @@ -394,3 +396,117 @@ def get_filecopy_info(cls):
for name, spec in sorted(inputs.traits(**metadata).items()):
info.append(dict(key=name, copy=spec.copyfile))
return info

def check_requires(inputs, requires):
"""check if required inputs are satisfied
"""
if not requires:
return True

# Check value and all required inputs' values defined
values = [isdefined(getattr(inputs, field))
for field in requires]
return all(values)

def check_xor(inputs, name, xor):
""" check if mutually exclusive inputs are satisfied
"""
if len(xor) == 0:
return True

values = [isdefined(getattr(inputs, name))]
values += [any([isdefined(getattr(inputs, field))
for field in xor])]
return sum(values)

def check_mandatory_inputs(inputs, raise_exc=True):
""" Raises an exception if a mandatory input is Undefined
"""
# Check mandatory, not xor-ed inputs.
for name, spec in list(inputs.traits(mandatory=True).items()):
value = getattr(inputs, name)
# Mandatory field is defined, check xor'ed inputs
xor = spec.xor or []
has_xor = bool(xor)
has_value = isdefined(value)

# Simplest case: no xor metadata and not defined
if not has_xor and not has_value:
if raise_exc:
raise MandatoryInputError(inputs, name)
return False

xor = set(list(xor) if isinstance(xor, (list, tuple))
else [xor])
xor.discard(name)
xor = list(xor)
cxor = check_xor(inputs, name, xor)
if cxor != 1:
if raise_exc:
raise MutuallyExclusiveInputError(
inputs, name, values_defined=cxor)
return False

# Check whether mandatory inputs require others
if has_value and not check_requires(inputs, spec.requires):
if raise_exc:
raise RequiredInputError(inputs, name)
return False

# Check requirements of non-mandatory inputs
for name, spec in list(
inputs.traits(mandatory=None, transient=None).items()):
value = getattr(inputs, name) # value must be set to follow requires
if isdefined(value) and not check_requires(inputs, spec.requires):
if raise_exc:
raise RequiredInputError(inputs, name)

return True

def check_version(traited_spec, version=None, raise_exc=True):
""" Raises an exception on version mismatch
"""

# no version passed on to check against
if not version:
return []

# check minimum version
names = traited_spec.trait_names(
min_ver=lambda t: t is not None) + \
traited_spec.trait_names(
max_ver=lambda t: t is not None)

# no traits defined any versions
if not names:
return []

version = Version(str(version))
unavailable_traits = []
for name in names:
value_set = isdefined(getattr(traited_spec, name))
min_ver = traited_spec.traits()[name].min_ver
if min_ver:
min_ver = Version(str(min_ver))

max_ver = traited_spec.traits()[name].max_ver
if max_ver:
max_ver = Version(str(max_ver))

if min_ver and max_ver:
if max_ver < min_ver:
raise AssertionError(
'Trait "%s" (%s) has incongruent version metadata '
'(``max_ver`` is lower than ``min_ver``).' % (
traited_spec.__class__.__name__, name))

if min_ver and (min_ver > version):
unavailable_traits.append(name)
if value_set and raise_exc:
raise VersionIOError(traited_spec, name, version)
if max_ver and (max_ver < version):
unavailable_traits.append(name)
if value_set and raise_exc:
raise VersionIOError(traited_spec, name, version)

return list(set(unavailable_traits))
Loading