Source code for capsul.pipeline.custom_nodes.map_node

# -*- coding: utf-8 -*-
'''
:class:`MapNode`
------------------------
'''


from __future__ import absolute_import
from capsul.pipeline.pipeline_nodes import Node, Plug
from soma.controller import Controller
import traits.api as traits
import sys
from six.moves import range
from six.moves import zip


class MapNode(Node):
    '''
    This node converts lists into series of single items. Typically, an input
    named ``inputs`` is a list of items. The node will separate the items and
    output each of them as an output parameter. The i-th item will be output
    as ``output_<i>`` by default. The input / output names can be customized
    using the constructor parameters ``input_names`` and ``output_names``.
    Several lists can be split in the same node.

    The node will also output a ``lengths`` parameter which will contain the
    input lists lengths. These lengths can typically be input in reduce nodes
    to perform the reverse operation (see
    :class:`~capsul.pipeline.custom_nodes.reduce_node.ReduceNode`).

    * ``input_names`` is a list of input parameter names, each being a list
      to be split. The default is ``['inputs']``.
    * ``output_names`` is a list of patterns used to build output parameter
      names. Each item is a string containing a substitution pattern ``"%d"``
      which will be replaced with a number. The default is ``['output_%d']``.
      Each pattern will be used to replace items from the corresponding input
      in the same order. Thus ``input_names`` and ``output_names`` should have
      the same length.
    '''

    _doc_path = 'api/pipeline.html#mapnode'

    def __init__(self, pipeline, name, input_names=['inputs'],
                 output_names=['output_%d'], input_types=None):
        in_traits = []
        out_traits = [{'name': 'lengths', 'optional': True}]
        if input_types:
            ptypes = input_types
        else:
            ptypes = [traits.File(traits.Undefined, output=False)] \
                * len(input_names)
        self.input_types = ptypes
        for tr in input_names:
            in_traits.append({'name': tr, 'optional': False})
        super(MapNode, self).__init__(pipeline, name, in_traits, out_traits)
        for tr, ptype in zip(input_names, ptypes):
            self.add_trait(tr, traits.List(ptype, output=False))
        self.add_trait('lengths', traits.List(traits.Int(), output=True,
                                              optional=True,
                                              desc='lists lengths'))
        self.input_names = input_names
        self.output_names = output_names
        self.lengths = [0] * len(input_names)
        self.set_callbacks()

    def set_callbacks(self):
        self.on_trait_change(self.map_callback, self.input_names)

    def map_callback(self, obj, name, old_value, value):
        index = self.input_names.index(name)
        output = self.output_names[index]
        ptype = self.input_types[index]
        if old_value in (None, traits.Undefined):
            old_value = []
        if value in (None, traits.Undefined):
            value = []
        if len(old_value) > len(value):
            for i in range(len(old_value) - 1, len(value) - 1, -1):
                pname = output % i
                self.remove_trait(pname)
                if pname in self.plugs:
                    # remove links to this plug
                    plug = self.plugs[pname]
                    to_del = []
                    for link in plug.links_to:
                        linkd = (self, pname, link[2], link[1])
                        to_del.append(linkd)
                    for linkd in to_del:
                        self.pipeline.remove_link(linkd)
                    del self.plugs[pname]
        for i in range(len(old_value), len(value)):
            pname = output % i
            ptype = self._clone_trait(ptype, {'output': True,
                                              'optional': True})
            self.add_trait(pname, ptype)
            plug = Plug(name=output % i, optional=True, output=True)
            self.plugs[pname] = plug
            plug.on_trait_change(
                self.pipeline.update_nodes_and_plugs_activation, "enabled")
        for i, val in enumerate(value):
            setattr(self, output % i, val)
        # update lengths
        lengths = self.lengths
        if not isinstance(lengths, list):
            lengths = []
        while len(lengths) <= index:
            lengths.append(0)
        lengths[index] = len(value)
        self.lengths = lengths

    def configured_controller(self):
        c = self.configure_controller()
        c.input_names = self.input_names
        c.output_names = self.output_names
        c.input_types = [p.trait_type.__class__.__name__
                         for p in self.input_types]
        return c

    @classmethod
    def configure_controller(cls):
        c = Controller()
        c.add_trait('input_types', traits.List(traits.Str))
        c.add_trait('input_names', traits.List(traits.Str))
        c.add_trait('output_names', traits.List(traits.Str))
        c.input_names = ['inputs']
        c.output_names = ['output_%d']
        c.input_types = ['File']
        return c

    @classmethod
    def build_node(cls, pipeline, name, conf_controller):
        t = []
        for ptype in conf_controller.input_types:
            if ptype == 'Str':
                t.append(traits.Str(traits.Undefined))
            elif ptype == 'File':
                t.append(traits.File(traits.Undefined))
            elif ptype not in (None, traits.Undefined):
                t.append(getattr(traits, ptype)())
        node = MapNode(pipeline, name, conf_controller.input_names,
                       conf_controller.output_names, input_types=t)
        return node

    def params_to_command(self):
        return ['custom_job']

    def build_job(self, name=None, referenced_input_files=[],
                  referenced_output_files=[], param_dict=None):
        from soma_workflow.custom_jobs import MapJob
        param_dict = dict(param_dict)
        param_dict['input_names'] = self.input_names
        param_dict['output_names'] = self.output_names
        for index, pname in enumerate(self.input_names):
            value = getattr(self, pname)
            param_dict[pname] = value
            output_name = self.output_names[index]
            if value not in (None, traits.Undefined):
                for i in range(len(value)):
                    opname = output_name % i
                    param_dict[opname] = getattr(self, opname)
        job = MapJob(name=name,
                     referenced_input_files=referenced_input_files,
                     referenced_output_files=referenced_output_files,
                     param_dict=param_dict)
        return job
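

# Usage sketch: a minimal, hedged example of exercising a MapNode directly,
# assuming a bare ``capsul.api.Pipeline`` instance is an acceptable owning
# pipeline (in real use, custom nodes are normally declared inside a pipeline
# definition).  The node name and file paths below are hypothetical.
if __name__ == '__main__':
    from capsul.api import Pipeline

    pipeline = Pipeline()

    # Build the node from its configuration controller, much as a pipeline
    # definition would: defaults are input_names=['inputs'],
    # output_names=['output_%d'], input_types=['File'].
    conf = MapNode.configure_controller()
    node = MapNode.build_node(pipeline, 'map', conf)

    # Assigning the input list triggers map_callback(), which creates one
    # output_<i> parameter per item and records the list length in 'lengths'.
    node.inputs = ['/tmp/a.nii', '/tmp/b.nii']
    print(node.output_0, node.output_1, node.lengths)  # the two items, [2]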