|
| 1 | +import pytest |
| 2 | +from collections import namedtuple |
| 3 | +from ...base import traits, TraitedSpec, BaseInterfaceInputSpec |
| 4 | +from ..base import (convert_to_traits_type, create_interface_specs, |
| 5 | + dipy_to_nipype_interface, DipyBaseInterface, no_dipy) |
| 6 | + |
| 7 | + |
| 8 | +def test_convert_to_traits_type(): |
| 9 | + Params = namedtuple("Params", "traits_type is_file") |
| 10 | + Res = namedtuple("Res", "traits_type is_mandatory") |
| 11 | + l_entries = [Params('variable string', False), |
| 12 | + Params('variable int', False), |
| 13 | + Params('variable float', False), |
| 14 | + Params('variable bool', False), |
| 15 | + Params('variable complex', False), |
| 16 | + Params('variable int, optional', False), |
| 17 | + Params('variable string, optional', False), |
| 18 | + Params('variable float, optional', False), |
| 19 | + Params('variable bool, optional', False), |
| 20 | + Params('variable complex, optional', False), |
| 21 | + Params('string', False), Params('int', False), |
| 22 | + Params('string', True), Params('float', False), |
| 23 | + Params('bool', False), Params('complex', False), |
| 24 | + Params('string, optional', False), |
| 25 | + Params('int, optional', False), |
| 26 | + Params('string, optional', True), |
| 27 | + Params('float, optional', False), |
| 28 | + Params('bool, optional', False), |
| 29 | + Params('complex, optional', False), |
| 30 | + ] |
| 31 | + l_expected = [Res(traits.ListStr, True), Res(traits.ListInt, True), |
| 32 | + Res(traits.ListFloat, True), Res(traits.ListBool, True), |
| 33 | + Res(traits.ListComplex, True), Res(traits.ListInt, False), |
| 34 | + Res(traits.ListStr, False), Res(traits.ListFloat, False), |
| 35 | + Res(traits.ListBool, False), Res(traits.ListComplex, False), |
| 36 | + Res(traits.Str, True), Res(traits.Int, True), |
| 37 | + Res(traits.File, True), Res(traits.Float, True), |
| 38 | + Res(traits.Bool, True), Res(traits.Complex, True), |
| 39 | + Res(traits.Str, False), Res(traits.Int, False), |
| 40 | + Res(traits.File, False), Res(traits.Float, False), |
| 41 | + Res(traits.Bool, False), Res(traits.Complex, False), |
| 42 | + ] |
| 43 | + |
| 44 | + for entry, res in zip(l_entries, l_expected): |
| 45 | + traits_type, is_mandatory = convert_to_traits_type(entry.traits_type, |
| 46 | + entry.is_file) |
| 47 | + assert traits_type == res.traits_type |
| 48 | + assert is_mandatory == res.is_mandatory |
| 49 | + |
| 50 | + with pytest.raises(IOError): |
| 51 | + convert_to_traits_type("file, optional") |
| 52 | + |
| 53 | + |
| 54 | +def test_create_interface_specs(): |
| 55 | + new_interface = create_interface_specs("MyInterface") |
| 56 | + |
| 57 | + assert new_interface.__base__ == TraitedSpec |
| 58 | + assert isinstance(new_interface(), TraitedSpec) |
| 59 | + assert new_interface.__name__ == "MyInterface" |
| 60 | + assert not new_interface().get() |
| 61 | + |
| 62 | + new_interface = create_interface_specs("MyInterface", |
| 63 | + BaseClass=BaseInterfaceInputSpec) |
| 64 | + assert new_interface.__base__ == BaseInterfaceInputSpec |
| 65 | + assert isinstance(new_interface(), BaseInterfaceInputSpec) |
| 66 | + assert new_interface.__name__ == "MyInterface" |
| 67 | + assert not new_interface().get() |
| 68 | + |
| 69 | + params = [("params1", "string", ["my description"]), ("params2_files", "string", ["my description @"]), |
| 70 | + ("params3", "int, optional", ["useful option"]), ("out_params", "string", ["my out description"])] |
| 71 | + |
| 72 | + new_interface = create_interface_specs("MyInterface", params=params, |
| 73 | + BaseClass=BaseInterfaceInputSpec) |
| 74 | + |
| 75 | + assert new_interface.__base__ == BaseInterfaceInputSpec |
| 76 | + assert isinstance(new_interface(), BaseInterfaceInputSpec) |
| 77 | + assert new_interface.__name__ == "MyInterface" |
| 78 | + current_params = new_interface().get() |
| 79 | + assert len(current_params) == 4 |
| 80 | + assert 'params1' in current_params.keys() |
| 81 | + assert 'params2_files' in current_params.keys() |
| 82 | + assert 'params3' in current_params.keys() |
| 83 | + assert 'out_params' in current_params.keys() |
| 84 | + |
| 85 | + |
| 86 | +@pytest.mark.skipif(no_dipy(), reason="DIPY is not installed") |
| 87 | +def test_dipy_to_nipype_interface(): |
| 88 | + from dipy.workflows.workflow import Workflow |
| 89 | + |
| 90 | + class DummyWorkflow(Workflow): |
| 91 | + |
| 92 | + @classmethod |
| 93 | + def get_short_name(cls): |
| 94 | + return 'dwf1' |
| 95 | + |
| 96 | + def run(self, in_files, param1=1, out_dir='', out_ref='out1.txt'): |
| 97 | + """Workflow used to test basic workflows. |
| 98 | +
|
| 99 | + Parameters |
| 100 | + ---------- |
| 101 | + in_files : string |
| 102 | + fake input string param |
| 103 | + param1 : int, optional |
| 104 | + fake positional param (default 1) |
| 105 | + out_dir : string, optional |
| 106 | + fake output directory (default '') |
| 107 | + out_ref : string, optional |
| 108 | + fake out file (default out1.txt) |
| 109 | +
|
| 110 | + References |
| 111 | + ----------- |
| 112 | + dummy references |
| 113 | +
|
| 114 | + """ |
| 115 | + return param1 |
| 116 | + |
| 117 | + new_specs = dipy_to_nipype_interface("MyModelSpec", DummyWorkflow) |
| 118 | + assert new_specs.__base__ == DipyBaseInterface |
| 119 | + assert isinstance(new_specs(), DipyBaseInterface) |
| 120 | + assert new_specs.__name__ == "MyModelSpec" |
| 121 | + assert hasattr(new_specs, 'input_spec') |
| 122 | + assert new_specs().input_spec.__base__ == BaseInterfaceInputSpec |
| 123 | + assert hasattr(new_specs, 'output_spec') |
| 124 | + assert new_specs().output_spec.__base__ == TraitedSpec |
| 125 | + assert hasattr(new_specs, '_run_interface') |
| 126 | + assert hasattr(new_specs, '_list_outputs') |
| 127 | + params_in = new_specs().inputs.get() |
| 128 | + params_out = new_specs()._outputs().get() |
| 129 | + assert len(params_in) == 4 |
| 130 | + assert 'in_files' in params_in.keys() |
| 131 | + assert 'param1' in params_in.keys() |
| 132 | + assert 'out_dir' in params_out.keys() |
| 133 | + assert 'out_ref' in params_out.keys() |
| 134 | + |
| 135 | + with pytest.raises(ValueError): |
| 136 | + new_specs().run() |
| 137 | + |
| 138 | + |
| 139 | +if __name__ == "__main__": |
| 140 | + test_convert_to_traits_type() |
| 141 | + test_create_interface_specs() |
| 142 | + test_dipy_to_nipype_interface() |
0 commit comments