diff --git a/nipype/pipeline/engine/utils.py b/nipype/pipeline/engine/utils.py index a5245dda48..5005eaa332 100644 --- a/nipype/pipeline/engine/utils.py +++ b/nipype/pipeline/engine/utils.py @@ -2,8 +2,7 @@ # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- # vi: set ft=python sts=4 ts=4 sw=4 et: """Utility routines for workflow graphs""" -from __future__ import (print_function, division, unicode_literals, - absolute_import) +from __future__ import print_function, division, unicode_literals, absolute_import from builtins import str, open, next, zip, range import os @@ -44,7 +43,12 @@ from ...utils.misc import str2bool from ...utils.functions import create_function_from_source from ...interfaces.base.traits_extension import ( - rebase_path_traits, resolve_path_traits, OutputMultiPath, isdefined, Undefined) + rebase_path_traits, + resolve_path_traits, + OutputMultiPath, + isdefined, + Undefined, +) from ...interfaces.base.support import Bunch, InterfaceResult from ...interfaces.base import CommandLine from ...interfaces.utility import IdentityInterface @@ -56,7 +60,7 @@ from funcsigs import signature standard_library.install_aliases() -logger = logging.getLogger('nipype.workflow') +logger = logging.getLogger("nipype.workflow") PY3 = sys.version_info[0] > 2 @@ -82,13 +86,12 @@ def save_hashfile(hashfile, hashed_inputs): # XXX - SG current workaround is to just # create the hashed file and not put anything # in it - with open(hashfile, 'wt') as fd: + with open(hashfile, "wt") as fd: fd.writelines(str(hashed_inputs)) - logger.debug('Unable to write a particular type to the json file') + logger.debug("Unable to write a particular type to the json file") else: - logger.critical('Unable to open the file in write mode: %s', - hashfile) + logger.critical("Unable to open the file in write mode: %s", hashfile) def nodelist_runner(nodes, updatehash=False, stop_first=False): @@ -108,49 +111,47 @@ def nodelist_runner(nodes, updatehash=False, stop_first=False): result = node.result err = [] - if result.runtime and hasattr(result.runtime, 'traceback'): + if result.runtime and hasattr(result.runtime, "traceback"): err = [result.runtime.traceback] err += format_exception(*sys.exc_info()) - err = '\n'.join(err) + err = "\n".join(err) finally: yield i, result, err def write_node_report(node, result=None, is_mapnode=False): """Write a report file for a node.""" - if not str2bool(node.config['execution']['create_report']): + if not str2bool(node.config["execution"]["create_report"]): return cwd = node.output_dir() - report_file = Path(cwd) / '_report' / 'report.rst' + report_file = Path(cwd) / "_report" / "report.rst" path_mkdir(report_file.parent, exist_ok=True, parents=True) lines = [ - write_rst_header('Node: %s' % get_print_name(node), level=0), - write_rst_list( - ['Hierarchy : %s' % node.fullname, - 'Exec ID : %s' % node._id]), - write_rst_header('Original Inputs', level=1), + write_rst_header("Node: %s" % get_print_name(node), level=0), + write_rst_list(["Hierarchy : %s" % node.fullname, "Exec ID : %s" % node._id]), + write_rst_header("Original Inputs", level=1), write_rst_dict(node.inputs.trait_get()), ] if result is None: logger.debug('[Node] Writing pre-exec report to "%s"', report_file) - report_file.write_text('\n'.join(lines)) + report_file.write_text("\n".join(lines)) return logger.debug('[Node] Writing post-exec report to "%s"', report_file) lines += [ - write_rst_header('Execution Inputs', level=1), + write_rst_header("Execution Inputs", level=1), 
write_rst_dict(node.inputs.trait_get()), - write_rst_header('Execution Outputs', level=1) + write_rst_header("Execution Outputs", level=1), ] outputs = result.outputs if outputs is None: - lines += ['None'] - report_file.write_text('\n'.join(lines)) + lines += ["None"] + report_file.write_text("\n".join(lines)) return if isinstance(outputs, Bunch): @@ -158,90 +159,98 @@ def write_node_report(node, result=None, is_mapnode=False): elif outputs: lines.append(write_rst_dict(outputs.trait_get())) else: - lines += ['Outputs object was empty.'] + lines += ["Outputs object was empty."] if is_mapnode: - lines.append(write_rst_header('Subnode reports', level=1)) + lines.append(write_rst_header("Subnode reports", level=1)) nitems = len(ensure_list(getattr(node.inputs, node.iterfield[0]))) subnode_report_files = [] for i in range(nitems): - subnode_file = Path(cwd) / 'mapflow' / ( - '_%s%d' % (node.name, i)) / '_report' / 'report.rst' - subnode_report_files.append('subnode %d : %s' % (i, subnode_file)) + subnode_file = ( + Path(cwd) + / "mapflow" + / ("_%s%d" % (node.name, i)) + / "_report" + / "report.rst" + ) + subnode_report_files.append("subnode %d : %s" % (i, subnode_file)) lines.append(write_rst_list(subnode_report_files)) - report_file.write_text('\n'.join(lines)) + report_file.write_text("\n".join(lines)) return - lines.append(write_rst_header('Runtime info', level=1)) + lines.append(write_rst_header("Runtime info", level=1)) # Init rst dictionary of runtime stats rst_dict = { - 'hostname': result.runtime.hostname, - 'duration': result.runtime.duration, - 'working_dir': result.runtime.cwd, - 'prev_wd': getattr(result.runtime, 'prevcwd', ''), + "hostname": result.runtime.hostname, + "duration": result.runtime.duration, + "working_dir": result.runtime.cwd, + "prev_wd": getattr(result.runtime, "prevcwd", ""), } - for prop in ('cmdline', 'mem_peak_gb', 'cpu_percent'): + for prop in ("cmdline", "mem_peak_gb", "cpu_percent"): if hasattr(result.runtime, prop): rst_dict[prop] = getattr(result.runtime, prop) lines.append(write_rst_dict(rst_dict)) # Collect terminal output - if hasattr(result.runtime, 'merged'): + if hasattr(result.runtime, "merged"): lines += [ - write_rst_header('Terminal output', level=2), + write_rst_header("Terminal output", level=2), write_rst_list(result.runtime.merged), ] - if hasattr(result.runtime, 'stdout'): + if hasattr(result.runtime, "stdout"): lines += [ - write_rst_header('Terminal - standard output', level=2), + write_rst_header("Terminal - standard output", level=2), write_rst_list(result.runtime.stdout), ] - if hasattr(result.runtime, 'stderr'): + if hasattr(result.runtime, "stderr"): lines += [ - write_rst_header('Terminal - standard error', level=2), + write_rst_header("Terminal - standard error", level=2), write_rst_list(result.runtime.stderr), ] # Store environment - if hasattr(result.runtime, 'environ'): + if hasattr(result.runtime, "environ"): lines += [ - write_rst_header('Environment', level=2), + write_rst_header("Environment", level=2), write_rst_dict(result.runtime.environ), ] - report_file.write_text('\n'.join(lines)) + report_file.write_text("\n".join(lines)) def write_report(node, report_type=None, is_mapnode=False): """Write a report file for a node - DEPRECATED""" - if report_type not in ('preexec', 'postexec'): + if report_type not in ("preexec", "postexec"): logger.warning('[Node] Unknown report type "%s".', report_type) return - write_node_report(node, is_mapnode=is_mapnode, - result=node.result if report_type == 'postexec' else None) + 
write_node_report( + node, + is_mapnode=is_mapnode, + result=node.result if report_type == "postexec" else None, + ) def save_resultfile(result, cwd, name, rebase=None): """Save a result pklz file to ``cwd``.""" if rebase is None: - rebase = config.getboolean('execution', 'use_relative_paths') + rebase = config.getboolean("execution", "use_relative_paths") cwd = os.path.abspath(cwd) - resultsfile = os.path.join(cwd, 'result_%s.pklz' % name) + resultsfile = os.path.join(cwd, "result_%s.pklz" % name) logger.debug("Saving results file: '%s'", resultsfile) if result.outputs is None: - logger.warning('Storing result file without outputs') + logger.warning("Storing result file without outputs") savepkl(resultsfile, result) return try: output_names = result.outputs.copyable_trait_names() except AttributeError: - logger.debug('Storing non-traited results, skipping rebase of paths') + logger.debug("Storing non-traited results, skipping rebase of paths") savepkl(resultsfile, result) return @@ -258,7 +267,8 @@ def save_resultfile(result, cwd, name, rebase=None): if isdefined(old): if result.outputs.trait(key).is_trait_type(OutputMultiPath): old = result.outputs.trait(key).handler.get_value( - result.outputs, key) + result.outputs, key + ) backup_traits[key] = old val = rebase_path_traits(result.outputs.trait(key), old, cwd) setattr(result.outputs, key, val) @@ -298,17 +308,19 @@ def load_resultfile(results_file, resolve=True): try: outputs = result.outputs.get() except TypeError: # This is a Bunch - logger.debug('Outputs object of loaded result %s is a Bunch.', results_file) + logger.debug("Outputs object of loaded result %s is a Bunch.", results_file) return result - logger.debug('Resolving paths in outputs loaded from results file.') + logger.debug("Resolving paths in outputs loaded from results file.") for trait_name, old in list(outputs.items()): if isdefined(old): if result.outputs.trait(trait_name).is_trait_type(OutputMultiPath): old = result.outputs.trait(trait_name).handler.get_value( - result.outputs, trait_name) - value = resolve_path_traits(result.outputs.trait(trait_name), old, - results_file.parent) + result.outputs, trait_name + ) + value = resolve_path_traits( + result.outputs.trait(trait_name), old, results_file.parent + ) setattr(result.outputs, trait_name, value) return result @@ -320,13 +332,13 @@ def strip_temp(files, wd): if isinstance(f, list): out.append(strip_temp(f, wd)) else: - out.append(f.replace(os.path.join(wd, '_tempinput'), wd)) + out.append(f.replace(os.path.join(wd, "_tempinput"), wd)) return out def _write_inputs(node): lines = [] - nodename = node.fullname.replace('.', '_') + nodename = node.fullname.replace(".", "_") for key, _ in list(node.inputs.items()): val = getattr(node.inputs, key) if isdefined(val): @@ -337,65 +349,72 @@ def _write_inputs(node): lines.append("%s.inputs.%s = '%s'" % (nodename, key, val)) else: funcname = [ - name for name in func.__globals__ - if name != '__builtins__' + name for name in func.__globals__ if name != "__builtins__" ][0] lines.append(pickle.loads(val)) if funcname == nodename: - lines[-1] = lines[-1].replace(' %s(' % funcname, - ' %s_1(' % funcname) - funcname = '%s_1' % funcname + lines[-1] = lines[-1].replace( + " %s(" % funcname, " %s_1(" % funcname + ) + funcname = "%s_1" % funcname + lines.append("from nipype.utils.functions import getsource") lines.append( - 'from nipype.utils.functions import getsource') - lines.append("%s.inputs.%s = getsource(%s)" % - (nodename, key, funcname)) + "%s.inputs.%s = getsource(%s)" % 
(nodename, key, funcname) + ) else: - lines.append('%s.inputs.%s = %s' % (nodename, key, val)) + lines.append("%s.inputs.%s = %s" % (nodename, key, val)) return lines -def format_node(node, format='python', include_config=False): +def format_node(node, format="python", include_config=False): """Format a node in a given output syntax.""" from .nodes import MapNode + lines = [] - name = node.fullname.replace('.', '_') - if format == 'python': + name = node.fullname.replace(".", "_") + if format == "python": klass = node.interface - importline = 'from %s import %s' % (klass.__module__, - klass.__class__.__name__) - comment = '# Node: %s' % node.fullname + importline = "from %s import %s" % (klass.__module__, klass.__class__.__name__) + comment = "# Node: %s" % node.fullname spec = signature(node.interface.__init__) args = [p.name for p in list(spec.parameters.values())] args = args[1:] if args: filled_args = [] for arg in args: - if hasattr(node.interface, '_%s' % arg): - filled_args.append('%s=%s' % - (arg, - getattr(node.interface, '_%s' % arg))) - args = ', '.join(filled_args) + if hasattr(node.interface, "_%s" % arg): + filled_args.append( + "%s=%s" % (arg, getattr(node.interface, "_%s" % arg)) + ) + args = ", ".join(filled_args) else: - args = '' + args = "" klass_name = klass.__class__.__name__ if isinstance(node, MapNode): - nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' \ - % (name, klass_name, args, node.iterfield, name) + nodedef = '%s = MapNode(%s(%s), iterfield=%s, name="%s")' % ( + name, + klass_name, + args, + node.iterfield, + name, + ) else: - nodedef = '%s = Node(%s(%s), name="%s")' \ - % (name, klass_name, args, name) + nodedef = '%s = Node(%s(%s), name="%s")' % (name, klass_name, args, name) lines = [importline, comment, nodedef] if include_config: lines = [ - importline, "from future import standard_library", + importline, + "from future import standard_library", "standard_library.install_aliases()", - "from collections import OrderedDict", comment, nodedef + "from collections import OrderedDict", + comment, + nodedef, ] - lines.append('%s.config = %s' % (name, node.config)) + lines.append("%s.config = %s" % (name, node.config)) if node.iterables is not None: - lines.append('%s.iterables = %s' % (name, node.iterables)) + lines.append("%s.iterables = %s" % (name, node.iterables)) lines.extend(_write_inputs(node)) return lines @@ -420,28 +439,26 @@ def modify_paths(object, relative=True, basedir=None): out = {} for key, val in sorted(object.items()): if isdefined(val): - out[key] = modify_paths( - val, relative=relative, basedir=basedir) + out[key] = modify_paths(val, relative=relative, basedir=basedir) elif isinstance(object, (list, tuple)): out = [] for val in object: if isdefined(val): - out.append( - modify_paths(val, relative=relative, basedir=basedir)) + out.append(modify_paths(val, relative=relative, basedir=basedir)) if isinstance(object, tuple): out = tuple(out) else: if isdefined(object): if isinstance(object, (str, bytes)) and os.path.isfile(object): if relative: - if config.getboolean('execution', 'use_relative_paths'): + if config.getboolean("execution", "use_relative_paths"): out = relpath(object, start=basedir) else: out = object else: out = os.path.abspath(os.path.join(basedir, object)) if not os.path.exists(out): - raise IOError('File %s not found' % out) + raise IOError("File %s not found" % out) else: out = object else: @@ -457,20 +474,20 @@ def get_print_name(node, simple_form=True): """ name = node.fullname - if hasattr(node, '_interface'): - 
pkglist = node.interface.__class__.__module__.split('.') + if hasattr(node, "_interface"): + pkglist = node.interface.__class__.__module__.split(".") interface = node.interface.__class__.__name__ - destclass = '' + destclass = "" if len(pkglist) > 2: - destclass = '.%s' % pkglist[2] + destclass = ".%s" % pkglist[2] if simple_form: name = node.fullname + destclass else: - name = '.'.join([node.fullname, interface]) + destclass + name = ".".join([node.fullname, interface]) + destclass if simple_form: - parts = name.split('.') + parts = name.split(".") if len(parts) > 2: - return ' ('.join(parts[1:]) + ')' + return " (".join(parts[1:]) + ")" elif len(parts) == 2: return parts[1] return name @@ -481,15 +498,16 @@ def _create_dot_graph(graph, show_connectinfo=False, simple_form=True): Ensures that edge info is pickleable. """ - logger.debug('creating dot graph') + logger.debug("creating dot graph") import networkx as nx + pklgraph = nx.DiGraph() for edge in graph.edges(): data = graph.get_edge_data(*edge) srcname = get_print_name(edge[0], simple_form=simple_form) destname = get_print_name(edge[1], simple_form=simple_form) if show_connectinfo: - pklgraph.add_edge(srcname, destname, l=str(data['connect'])) + pklgraph.add_edge(srcname, destname, l=str(data["connect"])) else: pklgraph.add_edge(srcname, destname) return pklgraph @@ -510,67 +528,86 @@ def _write_detailed_dot(graph, dotfilename): } """ import networkx as nx - text = ['digraph structs {', 'node [shape=record];'] + + text = ["digraph structs {", "node [shape=record];"] # write nodes edges = [] for n in nx.topological_sort(graph): nodename = n.itername inports = [] for u, v, d in graph.in_edges(nbunch=n, data=True): - for cd in d['connect']: + for cd in d["connect"]: if isinstance(cd[0], (str, bytes)): outport = cd[0] else: outport = cd[0][0] inport = cd[1] - ipstrip = 'in%s' % _replacefunk(inport) - opstrip = 'out%s' % _replacefunk(outport) + ipstrip = "in%s" % _replacefunk(inport) + opstrip = "out%s" % _replacefunk(outport) edges.append( - '%s:%s:e -> %s:%s:w;' % (u.itername.replace('.', ''), opstrip, - v.itername.replace('.', ''), ipstrip)) + "%s:%s:e -> %s:%s:w;" + % ( + u.itername.replace(".", ""), + opstrip, + v.itername.replace(".", ""), + ipstrip, + ) + ) if inport not in inports: inports.append(inport) - inputstr = ['{IN'] + [ - '| %s' % (_replacefunk(ip), ip) for ip in sorted(inports) - ] + ['}'] + inputstr = ( + ["{IN"] + + ["| %s" % (_replacefunk(ip), ip) for ip in sorted(inports)] + + ["}"] + ) outports = [] for u, v, d in graph.out_edges(nbunch=n, data=True): - for cd in d['connect']: + for cd in d["connect"]: if isinstance(cd[0], (str, bytes)): outport = cd[0] else: outport = cd[0][0] if outport not in outports: outports.append(outport) - outputstr = ['{OUT'] + [ - '| %s' % (_replacefunk(oport), oport) - for oport in sorted(outports) - ] + ['}'] - srcpackage = '' - if hasattr(n, '_interface'): - pkglist = n.interface.__class__.__module__.split('.') + outputstr = ( + ["{OUT"] + + [ + "| %s" % (_replacefunk(oport), oport) + for oport in sorted(outports) + ] + + ["}"] + ) + srcpackage = "" + if hasattr(n, "_interface"): + pkglist = n.interface.__class__.__module__.split(".") if len(pkglist) > 2: srcpackage = pkglist[2] - srchierarchy = '.'.join(nodename.split('.')[1:-1]) - nodenamestr = '{ %s | %s | %s }' % (nodename.split('.')[-1], - srcpackage, srchierarchy) + srchierarchy = ".".join(nodename.split(".")[1:-1]) + nodenamestr = "{ %s | %s | %s }" % ( + nodename.split(".")[-1], + srcpackage, + srchierarchy, + ) text += [ - '%s 
[label="%s|%s|%s"];' % - (nodename.replace('.', ''), ''.join(inputstr), nodenamestr, - ''.join(outputstr)) + '%s [label="%s|%s|%s"];' + % ( + nodename.replace(".", ""), + "".join(inputstr), + nodenamestr, + "".join(outputstr), + ) ] # write edges for edge in sorted(edges): text.append(edge) - text.append('}') - with open(dotfilename, 'wt') as filep: - filep.write('\n'.join(text)) + text.append("}") + with open(dotfilename, "wt") as filep: + filep.write("\n".join(text)) return text def _replacefunk(x): - return x.replace('_', '').replace('.', '').replace('@', '').replace( - '-', '') + return x.replace("_", "").replace(".", "").replace("@", "").replace("-", "") # Graph manipulations for iterable expansion @@ -582,9 +619,9 @@ def _get_valid_pathstr(pathstr): """ if not isinstance(pathstr, (str, bytes)): pathstr = to_str(pathstr) - pathstr = pathstr.replace(os.sep, '..') - pathstr = re.sub(r'''[][ (){}?:<>#!|"';]''', '', pathstr) - pathstr = pathstr.replace(',', '.') + pathstr = pathstr.replace(os.sep, "..") + pathstr = re.sub(r"""[][ (){}?:<>#!|"';]""", "", pathstr) + pathstr = pathstr.replace(",", ".") return pathstr @@ -657,8 +694,9 @@ def synchronize_iterables(iterables): True """ out_list = [] - iterable_items = [(field, iter(fvals())) - for field, fvals in sorted(iterables.items())] + iterable_items = [ + (field, iter(fvals())) for field, fvals in sorted(iterables.items()) + ] while True: cur_dict = {} for field, iter_values in iterable_items: @@ -679,17 +717,21 @@ def evaluate_connect_function(function_source, args, first_arg): try: output_value = func(first_arg, *list(args)) except NameError as e: - if e.args[0].startswith("global name") and \ - e.args[0].endswith("is not defined"): - e.args = (e.args[0], - ("Due to engine constraints all imports have to be done " - "inside each function definition")) + if e.args[0].startswith("global name") and e.args[0].endswith("is not defined"): + e.args = ( + e.args[0], + ( + "Due to engine constraints all imports have to be done " + "inside each function definition" + ), + ) raise e return output_value def get_levels(G): import networkx as nx + levels = {} for n in nx.topological_sort(G): levels[n] = 0 @@ -698,13 +740,9 @@ def get_levels(G): return levels -def _merge_graphs(supergraph, - nodes, - subgraph, - nodeid, - iterables, - prefix, - synchronize=False): +def _merge_graphs( + supergraph, nodes, subgraph, nodeid, iterables, prefix, synchronize=False +): """Merges two graphs that share a subset of nodes. If the subgraph needs to be replicated for multiple iterables, the @@ -740,8 +778,12 @@ def _merge_graphs(supergraph, # This should trap the problem of miswiring when multiple iterables are # used at the same level. The use of the template below for naming # updates to nodes is the general solution. - raise Exception(("Execution graph does not have a unique set of node " - "names. Please rerun the workflow")) + raise Exception( + ( + "Execution graph does not have a unique set of node " + "names. 
Please rerun the workflow" + ) + ) edgeinfo = {} for n in list(subgraph.nodes()): nidx = ids.index(n._hierarchy + n._id) @@ -751,7 +793,8 @@ def _merge_graphs(supergraph, if n._hierarchy + n._id not in list(edgeinfo.keys()): edgeinfo[n._hierarchy + n._id] = [] edgeinfo[n._hierarchy + n._id].append( - (edge[0], supergraph.get_edge_data(*edge))) + (edge[0], supergraph.get_edge_data(*edge)) + ) supergraph.remove_nodes_from(nodes) # Add copies of the subgraph depending on the number of iterables iterable_params = expand_iterables(iterables, synchronize) @@ -760,20 +803,21 @@ def _merge_graphs(supergraph, return supergraph # Make an iterable subgraph node id template count = len(iterable_params) - template = '.%s%%0%dd' % (prefix, np.ceil(np.log10(count))) + template = ".%s%%0%dd" % (prefix, np.ceil(np.log10(count))) # Copy the iterable subgraphs for i, params in enumerate(iterable_params): Gc = deepcopy(subgraph) ids = [n._hierarchy + n._id for n in Gc.nodes()] nodeidx = ids.index(nodeid) rootnode = list(Gc.nodes())[nodeidx] - paramstr = '' + paramstr = "" for key, val in sorted(params.items()): - paramstr = '{}_{}_{}'.format(paramstr, _get_valid_pathstr(key), - _get_valid_pathstr(val)) + paramstr = "{}_{}_{}".format( + paramstr, _get_valid_pathstr(key), _get_valid_pathstr(val) + ) rootnode.set_input(key, val) - logger.debug('Parameterization: paramstr=%s', paramstr) + logger.debug("Parameterization: paramstr=%s", paramstr) levels = get_levels(Gc) for n in Gc.nodes(): # update parameterization of the node to reflect the location of @@ -806,10 +850,10 @@ def _connect_nodes(graph, srcnode, destnode, connection_info): """ data = graph.get_edge_data(srcnode, destnode, default=None) if not data: - data = {'connect': connection_info} + data = {"connect": connection_info} graph.add_edges_from([(srcnode, destnode, data)]) else: - data['connect'].extend(connection_info) + data["connect"].extend(connection_info) def _remove_nonjoin_identity_nodes(graph, keep_iterables=False): @@ -821,7 +865,7 @@ def _remove_nonjoin_identity_nodes(graph, keep_iterables=False): # if keep_iterables is False, then include the iterable # and join nodes in the nodes to delete for node in _identity_nodes(graph, not keep_iterables): - if not hasattr(node, 'joinsource'): + if not hasattr(node, "joinsource"): _remove_identity_node(graph, node) return graph @@ -834,10 +878,12 @@ def _identity_nodes(graph, include_iterables): to True. 
""" import networkx as nx + return [ - node for node in nx.topological_sort(graph) - if isinstance(node.interface, IdentityInterface) and ( - include_iterables or getattr(node, 'iterables') is None) + node + for node in nx.topological_sort(graph) + if isinstance(node.interface, IdentityInterface) + and (include_iterables or getattr(node, "iterables") is None) ] @@ -847,8 +893,7 @@ def _remove_identity_node(graph, node): portinputs, portoutputs = _node_ports(graph, node) for field, connections in list(portoutputs.items()): if portinputs: - _propagate_internal_output(graph, node, field, connections, - portinputs) + _propagate_internal_output(graph, node, field, connections, portinputs) else: _propagate_root_output(graph, node, field, connections) graph.remove_nodes_from([node]) @@ -868,10 +913,10 @@ def _node_ports(graph, node): portinputs = {} portoutputs = {} for u, _, d in graph.in_edges(node, data=True): - for src, dest in d['connect']: + for src, dest in d["connect"]: portinputs[dest] = (u, src) for _, v, d in graph.out_edges(node, data=True): - for src, dest in d['connect']: + for src, dest in d["connect"]: if isinstance(src, tuple): srcport = src[0] else: @@ -902,25 +947,22 @@ def _propagate_internal_output(graph, node, field, connections, portinputs): if isinstance(srcport, tuple) and isinstance(src, tuple): src_func = srcport[1].split("\\n")[0] dst_func = src[1].split("\\n")[0] - raise ValueError("Does not support two inline functions " - "in series ('{}' and '{}'), found when " - "connecting {} to {}. Please use a Function " - "node.".format(src_func, dst_func, srcnode, - destnode)) - - connect = graph.get_edge_data( - srcnode, destnode, default={ - 'connect': [] - }) + raise ValueError( + "Does not support two inline functions " + "in series ('{}' and '{}'), found when " + "connecting {} to {}. Please use a Function " + "node.".format(src_func, dst_func, srcnode, destnode) + ) + + connect = graph.get_edge_data(srcnode, destnode, default={"connect": []}) if isinstance(src, tuple): - connect['connect'].append(((srcport, src[1], src[2]), inport)) + connect["connect"].append(((srcport, src[1], src[2]), inport)) else: - connect = {'connect': [(srcport, inport)]} + connect = {"connect": [(srcport, inport)]} old_connect = graph.get_edge_data( - srcnode, destnode, default={ - 'connect': [] - }) - old_connect['connect'] += connect['connect'] + srcnode, destnode, default={"connect": []} + ) + old_connect["connect"] += connect["connect"] graph.add_edges_from([(srcnode, destnode, old_connect)]) else: value = getattr(node.inputs, field) @@ -938,6 +980,7 @@ def generate_expanded_graph(graph_in): parameterized as (a=1,b=3), (a=1,b=4), (a=2,b=3) and (a=2,b=4). 
""" import networkx as nx + try: dfs_preorder = nx.dfs_preorder except AttributeError: @@ -949,7 +992,7 @@ def generate_expanded_graph(graph_in): for node in graph_in.nodes(): if node.iterables: _standardize_iterables(node) - allprefixes = list('abcdefghijklmnopqrstuvwxyz') + allprefixes = list("abcdefghijklmnopqrstuvwxyz") # the iterable nodes inodes = _iterable_nodes(graph_in) @@ -962,8 +1005,10 @@ def generate_expanded_graph(graph_in): # the join successor nodes of the current iterable node jnodes = [ - node for node in graph_in.nodes() - if hasattr(node, 'joinsource') and inode.name == node.joinsource + node + for node in graph_in.nodes() + if hasattr(node, "joinsource") + and inode.name == node.joinsource and nx.has_path(graph_in, inode, node) ] @@ -980,8 +1025,7 @@ def generate_expanded_graph(graph_in): for src, dest in edges2remove: graph_in.remove_edge(src, dest) - logger.debug("Excised the %s -> %s join node in-edge.", src, - dest) + logger.debug("Excised the %s -> %s join node in-edge.", src, dest) if inode.itersource: # the itersource is a (node name, fields) tuple @@ -991,22 +1035,24 @@ def generate_expanded_graph(graph_in): src_fields = [src_fields] # find the unique iterable source node in the graph try: - iter_src = next((node for node in graph_in.nodes() - if node.name == src_name - and nx.has_path(graph_in, node, inode))) + iter_src = next( + ( + node + for node in graph_in.nodes() + if node.name == src_name and nx.has_path(graph_in, node, inode) + ) + ) except StopIteration: - raise ValueError("The node %s itersource %s was not found" - " among the iterable predecessor nodes" % - (inode, src_name)) - logger.debug("The node %s has iterable source node %s", inode, - iter_src) + raise ValueError( + "The node %s itersource %s was not found" + " among the iterable predecessor nodes" % (inode, src_name) + ) + logger.debug("The node %s has iterable source node %s", inode, iter_src) # look up the iterables for this particular itersource descendant # using the iterable source ancestor values as a key iterables = {} # the source node iterables values - src_values = [ - getattr(iter_src.inputs, field) for field in src_fields - ] + src_values = [getattr(iter_src.inputs, field) for field in src_fields] # if there is one source field, then the key is the the source value, # otherwise the key is the tuple of source values if len(src_values) == 1: @@ -1016,9 +1062,13 @@ def generate_expanded_graph(graph_in): # The itersource iterables is a {field: lookup} dictionary, where the # lookup is a {source key: iteration list} dictionary. Look up the # current iterable value using the predecessor itersource input values. 
- iter_dict = dict([(field, lookup[key]) - for field, lookup in inode.iterables - if key in lookup]) + iter_dict = dict( + [ + (field, lookup[key]) + for field, lookup in inode.iterables + if key in lookup + ] + ) # convert the iterables to the standard {field: function} format @@ -1026,37 +1076,43 @@ def make_field_func(*pair): return pair[0], lambda: pair[1] iterables = dict( - [make_field_func(*pair) for pair in list(iter_dict.items())]) + [make_field_func(*pair) for pair in list(iter_dict.items())] + ) else: iterables = inode.iterables.copy() inode.iterables = None - logger.debug('node: %s iterables: %s', inode, iterables) + logger.debug("node: %s iterables: %s", inode, iterables) # collect the subnodes to expand subnodes = [s for s in dfs_preorder(graph_in, inode)] - prior_prefix = [re.findall(r'\.(.)I', s._id) for s in subnodes if s._id] + prior_prefix = [re.findall(r"\.(.)I", s._id) for s in subnodes if s._id] prior_prefix = sorted([l for item in prior_prefix for l in item]) if not prior_prefix: - iterable_prefix = 'a' + iterable_prefix = "a" else: - if prior_prefix[-1] == 'z': - raise ValueError('Too many iterables in the workflow') - iterable_prefix =\ - allprefixes[allprefixes.index(prior_prefix[-1]) + 1] - logger.debug(('subnodes:', subnodes)) + if prior_prefix[-1] == "z": + raise ValueError("Too many iterables in the workflow") + iterable_prefix = allprefixes[allprefixes.index(prior_prefix[-1]) + 1] + logger.debug(("subnodes:", subnodes)) # append a suffix to the iterable node id - inode._id += '.%sI' % iterable_prefix + inode._id += ".%sI" % iterable_prefix # merge the iterated subgraphs # dj: the behaviour of .copy changes in version 2 - if LooseVersion(nx.__version__) < LooseVersion('2'): + if LooseVersion(nx.__version__) < LooseVersion("2"): subgraph = graph_in.subgraph(subnodes) else: subgraph = graph_in.subgraph(subnodes).copy() - graph_in = _merge_graphs(graph_in, subnodes, subgraph, - inode._hierarchy + inode._id, iterables, - iterable_prefix, inode.synchronize) + graph_in = _merge_graphs( + graph_in, + subnodes, + subgraph, + inode._hierarchy + inode._id, + iterables, + iterable_prefix, + inode.synchronize, + ) # reconnect the join nodes for jnode in jnodes: @@ -1069,7 +1125,7 @@ def make_field_func(*pair): for src_id in list(old_edge_dict.keys()): # Drop the original JoinNodes; only concerned with # generated Nodes - if hasattr(node, 'joinfield') and node.itername == src_id: + if hasattr(node, "joinfield") and node.itername == src_id: continue # Patterns: # - src_id : Non-iterable node @@ -1078,12 +1134,17 @@ def make_field_func(*pair): # - src_id.[a-z]I.[a-z]\d+ : # Non-IdentityInterface w/ iterables # - src_idJ\d+ : JoinNode(IdentityInterface) - if re.match(src_id + r'((\.[a-z](I\.[a-z])?|J)\d+)?$', - node.itername): + if re.match( + src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", node.itername + ): expansions[src_id].append(node) for in_id, in_nodes in list(expansions.items()): - logger.debug("The join node %s input %s was expanded" - " to %d nodes.", jnode, in_id, len(in_nodes)) + logger.debug( + "The join node %s input %s was expanded" " to %d nodes.", + jnode, + in_id, + len(in_nodes), + ) # preserve the node iteration order by sorting on the node id for in_nodes in list(expansions.values()): in_nodes.sort(key=lambda node: node._id) @@ -1092,9 +1153,7 @@ def make_field_func(*pair): iter_cnt = count_iterables(iterables, inode.synchronize) # make new join node fields to connect to each replicated # join in-edge source node. 
- slot_dicts = [ - jnode._add_join_item_fields() for _ in range(iter_cnt) - ] + slot_dicts = [jnode._add_join_item_fields() for _ in range(iter_cnt)] # for each join in-edge, connect every expanded source node # which matches on the in-edge source name to the destination # join node. Qualify each edge connect join field name by @@ -1110,11 +1169,10 @@ def make_field_func(*pair): olddata = old_edge_dict[old_id] newdata = deepcopy(olddata) # the (source, destination) field tuples - connects = newdata['connect'] + connects = newdata["connect"] # the join fields connected to the source join_fields = [ - field for _, field in connects - if field in jnode.joinfield + field for _, field in connects if field in jnode.joinfield ] # the {field: slot fields} maps assigned to the input # node, e.g. {'image': 'imageJ3', 'mask': 'maskJ3'} @@ -1129,10 +1187,18 @@ def make_field_func(*pair): connects[con_idx] = (src_field, slot_field) logger.debug( "Qualified the %s -> %s join field %s as %s.", - in_node, jnode, dest_field, slot_field) + in_node, + jnode, + dest_field, + slot_field, + ) graph_in.add_edge(in_node, jnode, **newdata) - logger.debug("Connected the join node %s subgraph to the" - " expanded join point %s", jnode, in_node) + logger.debug( + "Connected the join node %s subgraph to the" + " expanded join point %s", + jnode, + in_node, + ) # nx.write_dot(graph_in, '%s_post.dot' % node) # the remaining iterable nodes @@ -1172,6 +1238,7 @@ def _iterable_nodes(graph_in): Return the iterable nodes list """ import networkx as nx + nodes = nx.topological_sort(graph_in) inodes = [node for node in nodes if node.iterables is not None] inodes_no_src = [node for node in inodes if not node.itersource] @@ -1194,8 +1261,9 @@ def _standardize_iterables(node): if node.synchronize: if len(iterables) == 2: first, last = iterables - if all((isinstance(item, (str, bytes)) and item in fields - for item in first)): + if all( + (isinstance(item, (str, bytes)) and item in fields for item in first) + ): iterables = _transpose_iterables(first, last) # Convert a tuple to a list @@ -1212,9 +1280,7 @@ def _standardize_iterables(node): def make_field_func(*pair): return pair[0], lambda: pair[1] - iter_items = [ - make_field_func(*field_value1) for field_value1 in iterables - ] + iter_items = [make_field_func(*field_value1) for field_value1 in iterables] iterables = dict(iter_items) node.iterables = iterables @@ -1231,20 +1297,25 @@ def _validate_iterables(node, iterables, fields): if isinstance(iterables, dict): iterables = list(iterables.items()) elif not isinstance(iterables, tuple) and not isinstance(iterables, list): - raise ValueError("The %s iterables type is not a list or a dictionary:" - " %s" % (node.name, iterables.__class__)) + raise ValueError( + "The %s iterables type is not a list or a dictionary:" + " %s" % (node.name, iterables.__class__) + ) for item in iterables: try: if len(item) != 2: - raise ValueError("The %s iterables is not a [(field, values)]" - " list" % node.name) + raise ValueError( + "The %s iterables is not a [(field, values)]" " list" % node.name + ) except TypeError as e: - raise TypeError("A %s iterables member is not iterable: %s" % - (node.name, e)) + raise TypeError( + "A %s iterables member is not iterable: %s" % (node.name, e) + ) field, _ = item if field not in fields: - raise ValueError("The %s iterables field is unrecognized: %s" % - (node.name, field)) + raise ValueError( + "The %s iterables field is unrecognized: %s" % (node.name, field) + ) def _transpose_iterables(fields, values): 
@@ -1267,18 +1338,26 @@ def _transpose_iterables(fields, values): return list(transposed.items()) return list( - zip(fields, [[v for v in list(transpose) if v is not None] - for transpose in zip(*values)])) - - -def export_graph(graph_in, - base_dir=None, - show=False, - use_execgraph=False, - show_connectinfo=False, - dotfilename='graph.dot', - format='png', - simple_form=True): + zip( + fields, + [ + [v for v in list(transpose) if v is not None] + for transpose in zip(*values) + ], + ) + ) + + +def export_graph( + graph_in, + base_dir=None, + show=False, + use_execgraph=False, + show_connectinfo=False, + dotfilename="graph.dot", + format="png", + simple_form=True, +): """ Displays the graph layout of the pipeline This function requires that pygraphviz and matplotlib are available on @@ -1300,37 +1379,40 @@ def export_graph(graph_in, makes the graph rather cluttered. default [False] """ import networkx as nx + graph = deepcopy(graph_in) if use_execgraph: graph = generate_expanded_graph(graph) - logger.debug('using execgraph') + logger.debug("using execgraph") else: - logger.debug('using input graph') + logger.debug("using input graph") if base_dir is None: base_dir = os.getcwd() makedirs(base_dir, exist_ok=True) out_dot = fname_presuffix( - dotfilename, suffix='_detailed.dot', use_ext=False, newpath=base_dir) + dotfilename, suffix="_detailed.dot", use_ext=False, newpath=base_dir + ) _write_detailed_dot(graph, out_dot) # Convert .dot if format != 'dot' outfname, res = _run_dot(out_dot, format_ext=format) if res is not None and res.runtime.returncode: - logger.warning('dot2png: %s', res.runtime.stderr) + logger.warning("dot2png: %s", res.runtime.stderr) pklgraph = _create_dot_graph(graph, show_connectinfo, simple_form) simple_dot = fname_presuffix( - dotfilename, suffix='.dot', use_ext=False, newpath=base_dir) + dotfilename, suffix=".dot", use_ext=False, newpath=base_dir + ) nx.drawing.nx_pydot.write_dot(pklgraph, simple_dot) # Convert .dot if format != 'dot' simplefname, res = _run_dot(simple_dot, format_ext=format) if res is not None and res.runtime.returncode: - logger.warning('dot2png: %s', res.runtime.stderr) + logger.warning("dot2png: %s", res.runtime.stderr) if show: - pos = nx.graphviz_layout(pklgraph, prog='dot') + pos = nx.graphviz_layout(pklgraph, prog="dot") nx.draw(pklgraph, pos) if show_connectinfo: nx.draw_networkx_edge_labels(pklgraph, pos) @@ -1338,7 +1420,7 @@ def export_graph(graph_in, return simplefname if simple_form else outfname -def format_dot(dotfilename, format='png'): +def format_dot(dotfilename, format="png"): """Dump a directed graph (Linux only; install via `brew` on OSX)""" try: formatted_dot, _ = _run_dot(dotfilename, format_ext=format) @@ -1351,14 +1433,13 @@ def format_dot(dotfilename, format='png'): def _run_dot(dotfilename, format_ext): - if format_ext == 'dot': + if format_ext == "dot": return dotfilename, None - dot_base = os.path.splitext(dotfilename)[0] - formatted_dot = '{}.{}'.format(dot_base, format_ext) + dot_base = os.path.splitext(dotfilename)[0] + formatted_dot = "{}.{}".format(dot_base, format_ext) cmd = 'dot -T{} -o"{}" "{}"'.format(format_ext, formatted_dot, dotfilename) - res = CommandLine(cmd, terminal_output='allatonce', - resource_monitor=False).run() + res = CommandLine(cmd, terminal_output="allatonce", resource_monitor=False).run() return formatted_dot, res @@ -1387,9 +1468,9 @@ def walk_outputs(object): else: if isdefined(object) and isinstance(object, (str, bytes)): if os.path.islink(object) or os.path.isfile(object): - out = 
[(filename, 'f') for filename in get_all_files(object)] + out = [(filename, "f") for filename in get_all_files(object)] elif os.path.isdir(object): - out = [(object, 'd')] + out = [(object, "d")] return out @@ -1399,53 +1480,54 @@ def walk_files(cwd): yield os.path.join(path, f) -def clean_working_directory(outputs, - cwd, - inputs, - needed_outputs, - config, - files2keep=None, - dirs2keep=None): +def clean_working_directory( + outputs, cwd, inputs, needed_outputs, config, files2keep=None, dirs2keep=None +): """Removes all files not needed for further analysis from the directory """ if not outputs: return outputs_to_keep = list(outputs.trait_get().keys()) - if needed_outputs and \ - str2bool(config['execution']['remove_unnecessary_outputs']): + if needed_outputs and str2bool(config["execution"]["remove_unnecessary_outputs"]): outputs_to_keep = needed_outputs # build a list of needed files output_files = [] outputdict = outputs.trait_get() for output in outputs_to_keep: output_files.extend(walk_outputs(outputdict[output])) - needed_files = [path for path, type in output_files if type == 'f'] - if str2bool(config['execution']['keep_inputs']): + needed_files = [path for path, type in output_files if type == "f"] + if str2bool(config["execution"]["keep_inputs"]): input_files = [] inputdict = inputs.trait_get() input_files.extend(walk_outputs(inputdict)) - needed_files += [path for path, type in input_files if type == 'f'] + needed_files += [path for path, type in input_files if type == "f"] for extra in [ - '_0x*.json', 'provenance.*', 'pyscript*.m', 'pyjobs*.mat', - 'command.txt', 'result*.pklz', '_inputs.pklz', '_node.pklz', - '.proc-*', + "_0x*.json", + "provenance.*", + "pyscript*.m", + "pyjobs*.mat", + "command.txt", + "result*.pklz", + "_inputs.pklz", + "_node.pklz", + ".proc-*", ]: needed_files.extend(glob(os.path.join(cwd, extra))) if files2keep: needed_files.extend(ensure_list(files2keep)) - needed_dirs = [path for path, type in output_files if type == 'd'] + needed_dirs = [path for path, type in output_files if type == "d"] if dirs2keep: needed_dirs.extend(ensure_list(dirs2keep)) - for extra in ['_nipype', '_report']: + for extra in ["_nipype", "_report"]: needed_dirs.extend(glob(os.path.join(cwd, extra))) temp = [] for filename in needed_files: temp.extend(get_related_files(filename)) needed_files = temp - logger.debug('Needed files: %s', ';'.join(needed_files)) - logger.debug('Needed dirs: %s', ';'.join(needed_dirs)) + logger.debug("Needed files: %s", ";".join(needed_files)) + logger.debug("Needed dirs: %s", ";".join(needed_dirs)) files2remove = [] - if str2bool(config['execution']['remove_unnecessary_outputs']): + if str2bool(config["execution"]["remove_unnecessary_outputs"]): for f in walk_files(cwd): if f not in needed_files: if not needed_dirs: @@ -1453,15 +1535,15 @@ def clean_working_directory(outputs, elif not any([f.startswith(dname) for dname in needed_dirs]): files2remove.append(f) else: - if not str2bool(config['execution']['keep_inputs']): + if not str2bool(config["execution"]["keep_inputs"]): input_files = [] inputdict = inputs.trait_get() input_files.extend(walk_outputs(inputdict)) - input_files = [path for path, type in input_files if type == 'f'] + input_files = [path for path, type in input_files if type == "f"] for f in walk_files(cwd): if f in input_files and f not in needed_files: files2remove.append(f) - logger.debug('Removing files: %s', ';'.join(files2remove)) + logger.debug("Removing files: %s", ";".join(files2remove)) for f in files2remove: os.remove(f) 
for key in outputs.copyable_trait_names(): @@ -1515,11 +1597,11 @@ def merge_bundles(g1, g2): return g1 -def write_workflow_prov(graph, filename=None, format='all'): +def write_workflow_prov(graph, filename=None, format="all"): """Write W3C PROV Model JSON file """ if not filename: - filename = os.path.join(os.getcwd(), 'workflow_provenance') + filename = os.path.join(os.getcwd(), "workflow_provenance") ps = ProvStore() @@ -1531,16 +1613,15 @@ def write_workflow_prov(graph, filename=None, format='all'): _, hashval, _, _ = node.hash_exists() attrs = { pm.PROV["type"]: nipype_ns[classname], - pm.PROV["label"]: '_'.join((classname, node.name)), - nipype_ns['hashval']: hashval + pm.PROV["label"]: "_".join((classname, node.name)), + nipype_ns["hashval"]: hashval, } process = ps.g.activity(get_id(), None, None, attrs) if isinstance(result.runtime, list): process.add_attributes({pm.PROV["type"]: nipype_ns["MapNode"]}) # add info about sub processes for idx, runtime in enumerate(result.runtime): - subresult = InterfaceResult( - result.interface[idx], runtime, outputs={}) + subresult = InterfaceResult(result.interface[idx], runtime, outputs={}) if result.inputs: if idx < len(result.inputs): subresult.inputs = result.inputs[idx] @@ -1550,14 +1631,12 @@ def write_workflow_prov(graph, filename=None, format='all'): if isdefined(values) and idx < len(values): subresult.outputs[key] = values[idx] sub_doc = ProvStore().add_results(subresult) - sub_bundle = pm.ProvBundle( - sub_doc.get_records(), identifier=get_id()) + sub_bundle = pm.ProvBundle(sub_doc.get_records(), identifier=get_id()) ps.g.add_bundle(sub_bundle) bundle_entity = ps.g.entity( sub_bundle.identifier, - other_attributes={ - 'prov:type': pm.PROV_BUNDLE - }) + other_attributes={"prov:type": pm.PROV_BUNDLE}, + ) ps.g.wasGeneratedBy(bundle_entity, process) else: process.add_attributes({pm.PROV["type"]: nipype_ns["Node"]}) @@ -1565,14 +1644,11 @@ def write_workflow_prov(graph, filename=None, format='all'): prov_doc = result.provenance else: prov_doc = ProvStore().add_results(result) - result_bundle = pm.ProvBundle( - prov_doc.get_records(), identifier=get_id()) + result_bundle = pm.ProvBundle(prov_doc.get_records(), identifier=get_id()) ps.g.add_bundle(result_bundle) bundle_entity = ps.g.entity( - result_bundle.identifier, - other_attributes={ - 'prov:type': pm.PROV_BUNDLE - }) + result_bundle.identifier, other_attributes={"prov:type": pm.PROV_BUNDLE} + ) ps.g.wasGeneratedBy(bundle_entity, process) processes.append(process) @@ -1581,7 +1657,8 @@ def write_workflow_prov(graph, filename=None, format='all'): for idx, edgeinfo in enumerate(graph.in_edges()): ps.g.wasStartedBy( processes[list(nodes).index(edgeinfo[1])], - starter=processes[list(nodes).index(edgeinfo[0])]) + starter=processes[list(nodes).index(edgeinfo[0])], + ) # write provenance ps.write_provenance(filename, format=format) @@ -1596,46 +1673,49 @@ def write_workflow_resources(graph, filename=None, append=None): import simplejson as json # Overwrite filename if nipype config is set - filename = config.get('monitoring', 'summary_file', filename) + filename = config.get("monitoring", "summary_file", filename) # If filename still does not make sense, store in $PWD if not filename: - filename = os.path.join(os.getcwd(), 'resource_monitor.json') + filename = os.path.join(os.getcwd(), "resource_monitor.json") if append is None: - append = str2bool(config.get('monitoring', 'summary_append', 'true')) + append = str2bool(config.get("monitoring", "summary_append", "true")) big_dict = { - 
'time': [], - 'name': [], - 'interface': [], - 'rss_GiB': [], - 'vms_GiB': [], - 'cpus': [], - 'mapnode': [], - 'params': [], + "time": [], + "name": [], + "interface": [], + "rss_GiB": [], + "vms_GiB": [], + "cpus": [], + "mapnode": [], + "params": [], } # If file exists, just append new profile information # If we append different runs, then we will see different # "bursts" of timestamps corresponding to those executions. if append and os.path.isfile(filename): - with open(filename, 'r' if PY3 else 'rb') as rsf: + with open(filename, "r" if PY3 else "rb") as rsf: big_dict = json.load(rsf) for _, node in enumerate(graph.nodes()): nodename = node.fullname classname = node.interface.__class__.__name__ - params = '' + params = "" if node.parameterization: - params = '_'.join(['{}'.format(p) for p in node.parameterization]) + params = "_".join(["{}".format(p) for p in node.parameterization]) try: rt_list = node.result.runtime except Exception: - logger.warning('Could not access runtime info for node %s' - ' (%s interface)', nodename, classname) + logger.warning( + "Could not access runtime info for node %s" " (%s interface)", + nodename, + classname, + ) continue if not isinstance(rt_list, list): @@ -1643,22 +1723,26 @@ def write_workflow_resources(graph, filename=None, append=None): for subidx, runtime in enumerate(rt_list): try: - nsamples = len(runtime.prof_dict['time']) + nsamples = len(runtime.prof_dict["time"]) except AttributeError: logger.warning( 'Could not retrieve profiling information for node "%s" ' - '(mapflow %d/%d).', nodename, subidx + 1, len(rt_list)) + "(mapflow %d/%d).", + nodename, + subidx + 1, + len(rt_list), + ) continue - for key in ['time', 'cpus', 'rss_GiB', 'vms_GiB']: + for key in ["time", "cpus", "rss_GiB", "vms_GiB"]: big_dict[key] += runtime.prof_dict[key] - big_dict['interface'] += [classname] * nsamples - big_dict['name'] += [nodename] * nsamples - big_dict['mapnode'] += [subidx] * nsamples - big_dict['params'] += [params] * nsamples + big_dict["interface"] += [classname] * nsamples + big_dict["name"] += [nodename] * nsamples + big_dict["mapnode"] += [subidx] * nsamples + big_dict["params"] += [params] * nsamples - with open(filename, 'w' if PY3 else 'wb') as rsf: + with open(filename, "w" if PY3 else "wb") as rsf: json.dump(big_dict, rsf, ensure_ascii=False) return filename @@ -1668,6 +1752,7 @@ def topological_sort(graph, depth_first=False): """Returns a depth first sorted order if depth_first is True """ import networkx as nx + nodesort = list(nx.topological_sort(graph)) if not depth_first: return nodesort, None @@ -1685,8 +1770,8 @@ def topological_sort(graph, depth_first=False): for node in desc: indices.append(nodesort.index(node)) nodes.extend( - np.array(nodesort)[np.array(indices)[np.argsort(indices)]] - .tolist()) + np.array(nodesort)[np.array(indices)[np.argsort(indices)]].tolist() + ) for node in desc: nodesort.remove(node) groups.extend([group] * len(desc)) diff --git a/nipype/utils/filemanip.py b/nipype/utils/filemanip.py index d846ce4bca..6897beb19f 100644 --- a/nipype/utils/filemanip.py +++ b/nipype/utils/filemanip.py @@ -3,8 +3,7 @@ # vi: set ft=python sts=4 ts=4 sw=4 et: """Miscellaneous file manipulation functions """ -from __future__ import (print_function, division, unicode_literals, - absolute_import) +from __future__ import print_function, division, unicode_literals, absolute_import import sys import pickle @@ -21,35 +20,34 @@ import contextlib import posixpath import simplejson as json -from filelock import SoftFileLock +from time 
import sleep, time from builtins import str, bytes, open -from .. import logging, config +from .. import logging, config, __version__ as version from .misc import is_container from future import standard_library + standard_library.install_aliases() -fmlogger = logging.getLogger('nipype.utils') +fmlogger = logging.getLogger("nipype.utils") -related_filetype_sets = [ - ('.hdr', '.img', '.mat'), - ('.nii', '.mat'), - ('.BRIK', '.HEAD'), -] +related_filetype_sets = [(".hdr", ".img", ".mat"), (".nii", ".mat"), (".BRIK", ".HEAD")] PY3 = sys.version_info[0] >= 3 try: from builtins import FileNotFoundError, FileExistsError except ImportError: # PY27 + class FileNotFoundError(OSError): # noqa """Defines the exception for Python 2.""" def __init__(self, path): """Initialize the exception.""" super(FileNotFoundError, self).__init__( - 2, 'No such file or directory', '%s' % path) + 2, "No such file or directory", "%s" % path + ) class FileExistsError(OSError): # noqa """Defines the exception for Python 2.""" @@ -57,7 +55,8 @@ class FileExistsError(OSError): # noqa def __init__(self, path): """Initialize the exception.""" super(FileExistsError, self).__init__( - 17, 'File or directory exists', '%s' % path) + 17, "File or directory exists", "%s" % path + ) USING_PATHLIB2 = False @@ -65,6 +64,7 @@ def __init__(self, path): from pathlib import Path except ImportError: from pathlib2 import Path + USING_PATHLIB2 = True @@ -104,11 +104,12 @@ def path_mkdir(path, mode=0o777, parents=False, exist_ok=False): return os.mkdir(str(path), mode=mode) -if not hasattr(Path, 'write_text'): +if not hasattr(Path, "write_text"): # PY34 - Path does not have write_text def _write_text(self, text): - with open(str(self), 'w') as f: + with open(str(self), "w") as f: f.write(text) + Path.write_text = _write_text @@ -152,8 +153,7 @@ def split_filename(fname): ext = None for special_ext in special_extensions: ext_len = len(special_ext) - if (len(fname) > ext_len) and \ - (fname[-ext_len:].lower() == special_ext.lower()): + if (len(fname) > ext_len) and (fname[-ext_len:].lower() == special_ext.lower()): ext = fname[-ext_len:] fname = fname[:-ext_len] break @@ -181,11 +181,11 @@ def to_str_py27(value): """ if isinstance(value, dict): - entry = '{}: {}'.format - retval = '{' + entry = "{}: {}".format + retval = "{" for key, val in list(value.items()): if len(retval) > 1: - retval += ', ' + retval += ", " kenc = repr(key) if kenc.startswith(("u'", 'u"')): kenc = kenc[1:] @@ -193,12 +193,12 @@ def to_str_py27(value): if venc.startswith(("u'", 'u"')): venc = venc[1:] retval += entry(kenc, venc) - retval += '}' + retval += "}" return retval istuple = isinstance(value, tuple) if isinstance(value, (tuple, list)): - retval = '(' if istuple else '[' + retval = "(" if istuple else "[" nels = len(value) for i, v in enumerate(value): venc = to_str_py27(v) @@ -207,11 +207,11 @@ def to_str_py27(value): retval += venc if i < nels - 1: - retval += ', ' + retval += ", " if istuple and nels == 1: - retval += ',' - retval += ')' if istuple else ']' + retval += "," + retval += ")" if istuple else "]" return retval retval = repr(value).decode() @@ -220,7 +220,7 @@ def to_str_py27(value): return retval -def fname_presuffix(fname, prefix='', suffix='', newpath=None, use_ext=True): +def fname_presuffix(fname, prefix="", suffix="", newpath=None, use_ext=True): """Manipulates path and name of input filename Parameters @@ -254,7 +254,7 @@ def fname_presuffix(fname, prefix='', suffix='', newpath=None, use_ext=True): """ pth, fname, ext = 
split_filename(fname) if not use_ext: - ext = '' + ext = "" # No need for isdefined: bool(Undefined) evaluates to False if newpath: @@ -262,7 +262,7 @@ def fname_presuffix(fname, prefix='', suffix='', newpath=None, use_ext=True): return op.join(pth, prefix + fname + suffix + ext) -def fnames_presuffix(fnames, prefix='', suffix='', newpath=None, use_ext=True): +def fnames_presuffix(fnames, prefix="", suffix="", newpath=None, use_ext=True): """Calls fname_presuffix for a list of files. """ f2 = [] @@ -276,7 +276,7 @@ def hash_rename(filename, hashvalue): and sets path to output_directory """ path, name, ext = split_filename(filename) - newfilename = ''.join((name, '_0x', hashvalue, ext)) + newfilename = "".join((name, "_0x", hashvalue, ext)) return op.join(path, newfilename) @@ -285,15 +285,14 @@ def check_forhash(filename): if isinstance(filename, list): filename = filename[0] path, name = op.split(filename) - if re.search('(_0x[a-z0-9]{32})', name): - hashvalue = re.findall('(_0x[a-z0-9]{32})', name) + if re.search("(_0x[a-z0-9]{32})", name): + hashvalue = re.findall("(_0x[a-z0-9]{32})", name) return True, hashvalue else: return False, None -def hash_infile(afile, chunk_len=8192, crypto=hashlib.md5, - raise_notfound=False): +def hash_infile(afile, chunk_len=8192, crypto=hashlib.md5, raise_notfound=False): """ Computes hash of a file using 'crypto' module @@ -317,7 +316,7 @@ def hash_infile(afile, chunk_len=8192, crypto=hashlib.md5, return None crypto_obj = crypto() - with open(afile, 'rb') as fp: + with open(afile, "rb") as fp: while True: data = fp.read(chunk_len) if not data: @@ -352,19 +351,19 @@ def _parse_mount_table(exit_code, output): # ^^^^ ^^^^^ # OSX mount example: /dev/disk2 on / (hfs, local, journaled) # ^ ^^^ - pattern = re.compile(r'.*? on (/.*?) (?:type |\()([^\s,\)]+)') + pattern = re.compile(r".*? on (/.*?) (?:type |\()([^\s,\)]+)") # Keep line and match for error reporting (match == None on failure) # Ignore empty lines - matches = [(l, pattern.match(l)) - for l in output.strip().splitlines() if l] + matches = [(l, pattern.match(l)) for l in output.strip().splitlines() if l] # (path, fstype) tuples, sorted by path length (longest first) - mount_info = sorted((match.groups() for _, match in matches - if match is not None), - key=lambda x: len(x[0]), reverse=True) - cifs_paths = [path for path, fstype in mount_info - if fstype.lower() == 'cifs'] + mount_info = sorted( + (match.groups() for _, match in matches if match is not None), + key=lambda x: len(x[0]), + reverse=True, + ) + cifs_paths = [path for path, fstype in mount_info if fstype.lower() == "cifs"] # Report failures as warnings for line, match in matches: @@ -372,7 +371,8 @@ def _parse_mount_table(exit_code, output): fmlogger.debug("Cannot parse mount line: '%s'", line) return [ - mount for mount in mount_info + mount + for mount in mount_info if any(mount[0].startswith(path) for path in cifs_paths) ] @@ -410,17 +410,19 @@ def on_cifs(fname): # Only the first match (most recent parent) counts for fspath, fstype in _cifs_table: if fname.startswith(fspath): - return fstype == 'cifs' + return fstype == "cifs" return False -def copyfile(originalfile, - newfile, - copy=False, - create_new=False, - hashmethod=None, - use_hardlink=False, - copy_related_files=True): +def copyfile( + originalfile, + newfile, + copy=False, + create_new=False, + hashmethod=None, + use_hardlink=False, + copy_related_files=True, +): """Copy or link ``originalfile`` to ``newfile``. 
If ``use_hardlink`` is True, and the file can be hard-linked, then a @@ -457,7 +459,7 @@ def copyfile(originalfile, if create_new: while op.exists(newfile): base, fname, ext = split_filename(newfile) - s = re.search('_c[0-9]{4,4}$', fname) + s = re.search("_c[0-9]{4,4}$", fname) i = 0 if s: i = int(s.group()[2:]) + 1 @@ -467,7 +469,7 @@ def copyfile(originalfile, newfile = base + os.sep + fname + ext if hashmethod is None: - hashmethod = config.get('execution', 'hash_method').lower() + hashmethod = config.get("execution", "hash_method").lower() # Don't try creating symlinks on CIFS if copy is False and on_cifs(newfile): @@ -487,26 +489,33 @@ def copyfile(originalfile, keep = False if op.lexists(newfile): if op.islink(newfile): - if all((os.readlink(newfile) == op.realpath(originalfile), - not use_hardlink, not copy)): + if all( + ( + os.readlink(newfile) == op.realpath(originalfile), + not use_hardlink, + not copy, + ) + ): keep = True elif posixpath.samefile(newfile, originalfile): keep = True else: - if hashmethod == 'timestamp': + if hashmethod == "timestamp": hashfn = hash_timestamp - elif hashmethod == 'content': + elif hashmethod == "content": hashfn = hash_infile else: raise AttributeError("Unknown hash method found:", hashmethod) newhash = hashfn(newfile) - fmlogger.debug('File: %s already exists,%s, copy:%d', newfile, - newhash, copy) + fmlogger.debug( + "File: %s already exists,%s, copy:%d", newfile, newhash, copy + ) orighash = hashfn(originalfile) keep = newhash == orighash if keep: - fmlogger.debug('File: %s already exists, not overwriting, copy:%d', - newfile, copy) + fmlogger.debug( + "File: %s already exists, not overwriting, copy:%d", newfile, copy + ) else: os.unlink(newfile) @@ -517,7 +526,7 @@ def copyfile(originalfile, # ~hardlink & ~symlink => copy if not keep and use_hardlink: try: - fmlogger.debug('Linking File: %s->%s', newfile, originalfile) + fmlogger.debug("Linking File: %s->%s", newfile, originalfile) # Use realpath to avoid hardlinking symlinks os.link(op.realpath(originalfile), newfile) except OSError: @@ -525,9 +534,9 @@ def copyfile(originalfile, else: keep = True - if not keep and not copy and os.name == 'posix': + if not keep and not copy and os.name == "posix": try: - fmlogger.debug('Symlinking File: %s->%s', newfile, originalfile) + fmlogger.debug("Symlinking File: %s->%s", newfile, originalfile) os.symlink(originalfile, newfile) except OSError: copy = True # Disable symlink for associated files @@ -536,15 +545,17 @@ def copyfile(originalfile, if not keep: try: - fmlogger.debug('Copying File: %s->%s', newfile, originalfile) + fmlogger.debug("Copying File: %s->%s", newfile, originalfile) shutil.copyfile(originalfile, newfile) except shutil.Error as e: fmlogger.warning(e.message) # Associated files if copy_related_files: - related_file_pairs = (get_related_files(f, include_this_file=False) - for f in (originalfile, newfile)) + related_file_pairs = ( + get_related_files(f, include_this_file=False) + for f in (originalfile, newfile) + ) for alt_ofile, alt_nfile in zip(*related_file_pairs): if op.exists(alt_ofile): copyfile( @@ -553,7 +564,8 @@ def copyfile(originalfile, copy, hashmethod=hashmethod, use_hardlink=use_hardlink, - copy_related_files=False) + copy_related_files=False, + ) return newfile @@ -606,9 +618,7 @@ def copyfiles(filelist, dest, copy=False, create_new=False): newfiles = [] for i, f in enumerate(ensure_list(filelist)): if isinstance(f, list): - newfiles.insert(i, - copyfiles( - f, dest, copy=copy, create_new=create_new)) + 
newfiles.insert(i, copyfiles(f, dest, copy=copy, create_new=create_new)) else: if len(outfiles) > 1: destfile = outfiles[i] @@ -653,9 +663,9 @@ def check_depends(targets, dependencies): """ tgts = ensure_list(targets) deps = ensure_list(dependencies) - return all(map(op.exists, tgts)) and \ - min(map(op.getmtime, tgts)) > \ - max(list(map(op.getmtime, deps)) + [0]) + return all(map(op.exists, tgts)) and min(map(op.getmtime, tgts)) > max( + list(map(op.getmtime, deps)) + [0] + ) def save_json(filename, data): @@ -669,9 +679,9 @@ def save_json(filename, data): Dictionary to save in json file. """ - mode = 'w' + mode = "w" if sys.version_info[0] < 3: - mode = 'wb' + mode = "wb" with open(filename, mode) as fp: json.dump(data, fp, sort_keys=True, indent=4) @@ -690,32 +700,48 @@ def load_json(filename): """ - with open(filename, 'r') as fp: + with open(filename, "r") as fp: data = json.load(fp) return data def loadcrash(infile, *args): - if infile.endswith('pkl') or infile.endswith('pklz'): + if infile.endswith("pkl") or infile.endswith("pklz"): return loadpkl(infile) else: - raise ValueError('Only pickled crashfiles are supported') + raise ValueError("Only pickled crashfiles are supported") def loadpkl(infile): """Load a zipped or plain cPickled file.""" infile = Path(infile) - fmlogger.debug('Loading pkl: %s', infile) - pklopen = gzip.open if infile.suffix == '.pklz' else open - - with SoftFileLock('%s.lock' % infile): - with pklopen(str(infile), 'rb') as pkl_file: - pkl_contents = pkl_file.read() + fmlogger.debug("Loading pkl: %s", infile) + pklopen = gzip.open if infile.suffix == ".pklz" else open + + t = time() + timeout = float(config.get("execution", "job_finished_timeout")) + timed_out = True + while (time() - t) < timeout: + if infile.exists(): + timed_out = False + break + fmlogger.debug("'{}' missing; waiting 2s".format(infile)) + sleep(2) + if timed_out: + error_message = ( + "Result file {0} expected, but " + "does not exist after ({1}) " + "seconds.".format(infile, timeout) + ) + raise IOError(error_message) + + with pklopen(str(infile), "rb") as pkl_file: + pkl_contents = pkl_file.read() pkl_metadata = None # Look if pkl file contains version metadata - idx = pkl_contents.find(b'\n') + idx = pkl_contents.find(b"\n") if idx >= 0: try: pkl_metadata = json.loads(pkl_contents[:idx]) @@ -724,7 +750,7 @@ def loadpkl(infile): pass else: # On success, skip JSON metadata - pkl_contents = pkl_contents[idx + 1:] + pkl_contents = pkl_contents[idx + 1 :] # Pickle files may contain relative paths that must be resolved relative # to the working directory, so use indirectory while attempting to load @@ -735,38 +761,45 @@ def loadpkl(infile): except UnicodeDecodeError: # Was this pickle created with Python 2.x? 
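# NOTE (standalone sketch, not part of this patch): the polling pattern that loadpkl()
# adopts above -- wait for the results file to appear, up to the configured
# "job_finished_timeout", checking every couple of seconds. The helper name
# wait_for_file is illustrative only.
from pathlib import Path
from time import sleep, time

def wait_for_file(path, timeout, poll=2.0):
    """Return True once *path* exists, or False after *timeout* seconds."""
    path = Path(path)
    deadline = time() + timeout
    while time() < deadline:
        if path.exists():
            return True
        sleep(poll)
    return path.exists()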
with indirectory(infile.parent): - unpkl = pickle.loads(pkl_contents, fix_imports=True, encoding='utf-8') - fmlogger.info('Successfully loaded pkl in compatibility mode.') + unpkl = pickle.loads(pkl_contents, fix_imports=True, encoding="utf-8") + fmlogger.info("Successfully loaded pkl in compatibility mode.") # Unpickling problems except Exception as e: - if pkl_metadata and 'version' in pkl_metadata: + if pkl_metadata and "version" in pkl_metadata: from nipype import __version__ as version - if pkl_metadata['version'] != version: - fmlogger.error("""\ + + if pkl_metadata["version"] != version: + fmlogger.error( + """\ Attempted to open a results file generated by Nipype version %s, \ -with an incompatible Nipype version (%s)""", pkl_metadata['version'], version) +with an incompatible Nipype version (%s)""", + pkl_metadata["version"], + version, + ) raise e - fmlogger.warning("""\ + fmlogger.warning( + """\ No metadata was found in the pkl file. Make sure you are currently using \ -the same Nipype version from the generated pkl.""") +the same Nipype version from the generated pkl.""" + ) raise e if unpkl is None: - raise ValueError('Loading %s resulted in None.' % infile) + raise ValueError("Loading %s resulted in None." % infile) return unpkl def crash2txt(filename, record): """ Write out plain text crash file """ - with open(filename, 'w') as fp: - if 'node' in record: - node = record['node'] - fp.write('Node: {}\n'.format(node.fullname)) - fp.write('Working directory: {}\n'.format(node.output_dir())) - fp.write('\n') - fp.write('Node inputs:\n{}\n'.format(node.inputs)) - fp.write(''.join(record['traceback'])) + with open(filename, "w") as fp: + if "node" in record: + node = record["node"] + fp.write("Node: {}\n".format(node.fullname)) + fp.write("Working directory: {}\n".format(node.output_dir())) + fp.write("\n") + fp.write("Node inputs:\n{}\n".format(node.inputs)) + fp.write("".join(record["traceback"])) def read_stream(stream, logger=None, encoding=None): @@ -779,50 +812,53 @@ """ - default_encoding = encoding or locale.getdefaultlocale()[1] or 'UTF-8' + default_encoding = encoding or locale.getdefaultlocale()[1] or "UTF-8" logger = logger or fmlogger try: out = stream.decode(default_encoding) except UnicodeDecodeError as err: - out = stream.decode(default_encoding, errors='replace') - logger.warning('Error decoding string: %s', err) + out = stream.decode(default_encoding, errors="replace") + logger.warning("Error decoding string: %s", err) return out.splitlines() def savepkl(filename, record, versioning=False): - pklopen = gzip.open if filename.endswith('.pklz') else open - with SoftFileLock('%s.lock' % filename): - with pklopen(filename, 'wb') as pkl_file: - if versioning: - from nipype import __version__ as version - metadata = json.dumps({'version': version}) + from io import BytesIO - pkl_file.write(metadata.encode('utf-8')) - pkl_file.write('\n'.encode('utf-8')) + with BytesIO() as f: + if versioning: + from nipype import __version__ as version + metadata = json.dumps({"version": version}) + f.write(metadata.encode("utf-8")) + f.write("\n".encode("utf-8")) + pickle.dump(record, f) + content = f.getvalue() - pickle.dump(record, pkl_file) + pkl_open = gzip.open if filename.endswith(".pklz") else open + tmpfile = filename + ".tmp" + with pkl_open(tmpfile, "wb") as pkl_file: + pkl_file.write(content) + os.rename(tmpfile, filename) -rst_levels = ['=', '-', '~', '+'] +rst_levels = ["=", "-", "~", "+"] def write_rst_header(header, level=0): - return '\n'.join( - (header,
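# NOTE (hedged round-trip sketch, not part of this patch): using the rewritten
# savepkl()/loadpkl() pair, assuming both live in nipype.utils.filemanip as in released
# nipype. savepkl() now pickles into an in-memory buffer, writes "<name>.tmp" and
# renames it into place, so a concurrent reader never observes a half-written file.
from nipype.utils.filemanip import loadpkl, savepkl

record = {"inputs": {"in_file": "/data/T1w.nii.gz"}}   # hypothetical payload
savepkl("result_example.pklz", record, versioning=True)
roundtrip = loadpkl("result_example.pklz")
assert roundtrip["inputs"]["in_file"] == "/data/T1w.nii.gz"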
''.join([rst_levels[level] for _ in header]))) + '\n\n' + return "\n".join((header, "".join([rst_levels[level] for _ in header]))) + "\n\n" -def write_rst_list(items, prefix=''): +def write_rst_list(items, prefix=""): out = [] for item in items: - out.append('{} {}'.format(prefix, str(item))) - return '\n'.join(out) + '\n\n' + out.append("{} {}".format(prefix, str(item))) + return "\n".join(out) + "\n\n" -def write_rst_dict(info, prefix=''): +def write_rst_dict(info, prefix=""): out = [] for key, value in sorted(info.items()): - out.append('{}* {} : {}'.format(prefix, key, str(value))) - return '\n'.join(out) + '\n\n' + out.append("{}* {} : {}".format(prefix, key, str(value))) + return "\n".join(out) + "\n\n" def dist_is_editable(dist): @@ -836,7 +872,7 @@ def dist_is_editable(dist): # Borrowed from `pip`'s' API """ for path_item in sys.path: - egg_link = op.join(path_item, dist + '.egg-link') + egg_link = op.join(path_item, dist + ".egg-link") if op.isfile(egg_link): return True return False @@ -863,7 +899,7 @@ def makedirs(path, mode=0o777, exist_ok=False): except OSError: fmlogger.debug("Problem creating directory %s", path) if not op.exists(path): - raise OSError('Could not create directory %s' % path) + raise OSError("Could not create directory %s" % path) return path @@ -891,11 +927,12 @@ def emptydirs(path, noexist_ok=False): elcont = os.listdir(path) if ex.errno == errno.ENOTEMPTY and not elcont: fmlogger.warning( - 'An exception was raised trying to remove old %s, but the path' - ' seems empty. Is it an NFS mount?. Passing the exception.', - path) + "An exception was raised trying to remove old %s, but the path" + " seems empty. Is it an NFS mount?. Passing the exception.", + path, + ) elif ex.errno == errno.ENOTEMPTY and elcont: - fmlogger.debug('Folder %s contents (%d items).', path, len(elcont)) + fmlogger.debug("Folder %s contents (%d items).", path, len(elcont)) raise ex else: raise ex @@ -935,11 +972,11 @@ def which(cmd, env=None, pathext=None): """ if pathext is None: - pathext = os.getenv('PATHEXT', '').split(os.pathsep) - pathext.insert(0, '') + pathext = os.getenv("PATHEXT", "").split(os.pathsep) + pathext.insert(0, "") path = os.getenv("PATH", os.defpath) - if env and 'PATH' in env: + if env and "PATH" in env: path = env.get("PATH") if sys.version_info >= (3, 3): @@ -973,27 +1010,25 @@ def get_dependencies(name, environ): """ command = None - if sys.platform == 'darwin': - command = 'otool -L `which %s`' % name - elif 'linux' in sys.platform: - command = 'ldd `which %s`' % name + if sys.platform == "darwin": + command = "otool -L `which %s`" % name + elif "linux" in sys.platform: + command = "ldd `which %s`" % name else: - return 'Platform %s not supported' % sys.platform + return "Platform %s not supported" % sys.platform deps = None try: proc = sp.Popen( - command, - stdout=sp.PIPE, - stderr=sp.PIPE, - shell=True, - env=environ) + command, stdout=sp.PIPE, stderr=sp.PIPE, shell=True, env=environ + ) o, e = proc.communicate() deps = o.rstrip() except Exception as ex: deps = '"%s" failed' % command - fmlogger.warning('Could not get dependencies of %s. Error:\n%s', - name, ex.message) + fmlogger.warning( + "Could not get dependencies of %s. 
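# NOTE (illustration, not part of this patch): output of the reST helpers above,
# assuming they are importable from nipype.utils.filemanip as in released versions.
from nipype.utils.filemanip import write_rst_dict, write_rst_header

print(write_rst_header("Runtime info", level=1), end="")
# Runtime info
# ------------
print(write_rst_dict({"duration": 1.5, "hostname": "node01"}), end="")
# * duration : 1.5
# * hostname : node01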
Error:\n%s", name, ex.message + ) return deps @@ -1013,7 +1048,7 @@ def canonicalize_env(env): Windows: environment dictionary with bytes keys and values Other: untouched input ``env`` """ - if os.name != 'nt': + if os.name != "nt": return env # convert unicode to string for python 2 @@ -1022,9 +1057,9 @@ def canonicalize_env(env): out_env = {} for key, val in env.items(): if not isinstance(key, bytes): - key = key.encode('utf-8') + key = key.encode("utf-8") if not isinstance(val, bytes): - val = val.encode('utf-8') + val = val.encode("utf-8") if not PY3: key = bytes_to_native_str(key) val = bytes_to_native_str(val) @@ -1050,11 +1085,13 @@ def relpath(path, start=None): unc_path, rest = op.splitunc(path) unc_start, rest = op.splitunc(start) if bool(unc_path) ^ bool(unc_start): - raise ValueError(("Cannot mix UNC and non-UNC paths " - "(%s and %s)") % (path, start)) + raise ValueError( + ("Cannot mix UNC and non-UNC paths " "(%s and %s)") % (path, start) + ) else: - raise ValueError("path is on drive %s, start on drive %s" % - (path_list[0], start_list[0])) + raise ValueError( + "path is on drive %s, start on drive %s" % (path_list[0], start_list[0]) + ) # Work out how much of the filepath is shared by start and path. for i in range(min(len(start_list), len(path_list))): if start_list[i].lower() != path_list[i].lower():