Source code for capsul.pipeline.custom_nodes.reduce_node

# -*- coding: utf-8 -*-
'''
:class:`ReduceNode`
------------------------
'''

from __future__ import print_function
from __future__ import absolute_import
from capsul.pipeline.pipeline_nodes import Node, Plug
from soma.controller import Controller
import traits.api as traits
import sys
from six.moves import range
from six.moves import zip


class ReduceNode(Node):
    '''
    Reduce node: converts series of inputs into lists. Typically a series of
    inputs named ``input_0`` .. ``input_<n>`` will be output as a single
    list named ``outputs``.

    Several input series can be handled by the node, and input names can be
    customized.

    * The number of inputs in each series is given as the ``lengths`` input
      parameter. It is typically linked from the output of a
      :class:`~capsul.pipeline.custom_nodes.map_node.MapNode`.
    * Input parameter name patterns are given as the ``input_names``
      parameter. It is a list of patterns, each containing a ``"%d"``
      placeholder for the input number. The default value is
      ``['input_%d']``.
    * Output parameter names are given as the ``output_names`` parameter.
      The default is ``['outputs']``.
    '''

    _doc_path = 'api/pipeline.html#reducenode'

    def __init__(self, pipeline, name, input_names=['input_%d'],
                 output_names=['outputs'], input_types=None):
        in_traits = [{'name': 'lengths', 'optional': True},
                     {'name': 'skip_empty', 'optional': True}]
        out_traits = []
        if input_types:
            ptypes = input_types
        else:
            ptypes = [traits.File(traits.Undefined, output=False)] \
                * len(input_names)
        self.input_types = ptypes

        for tr in output_names:
            out_traits.append({'name': tr, 'optional': False})
        super(ReduceNode, self).__init__(pipeline, name, in_traits,
                                         out_traits)

        for tr, ptype in zip(output_names, ptypes):
            self.add_trait(tr, traits.List(
                traits.Either(ptype, traits.Undefined), output=True))
        self.add_trait('lengths', traits.List(traits.Int(), output=False,
                                              desc='lists lengths'))
        self.add_trait('skip_empty', traits.Bool(
            False, output=False,
            desc='remove empty (Undefined, None, empty strings) '
            'from the output lists'))
        self.input_names = input_names
        self.output_names = output_names
        self.lengths = [0] * len(input_names)

        self.set_callbacks()

        self.lengths = [1] * len(input_names)

    def set_callbacks(self):
        self.on_trait_change(self.resize_callback, 'lengths')
        self.on_trait_change(self.reduce_callback, 'skip_empty')

    def resize_callback(self, obj, name, old_value, value):
        if old_value in (None, traits.Undefined):
            old_value = [0] * len(self.input_names)
        if value in (None, traits.Undefined):
            value = [0] * len(self.input_names)

        # remove former callback
        inputs = []
        for i, pname_p in enumerate(self.input_names):
            inputs += [pname_p % j for j in range(old_value[i])]
        self.on_trait_change(self.reduce_callback, inputs, remove=True)

        # adjust sizes
        for in_index in range(len(value)):
            ptype = self.input_types[in_index]
            pname_p = self.input_names[in_index]
            oval = old_value[in_index]
            val = value[in_index]
            if oval > val:
                for i in range(oval - 1, val - 1, -1):
                    pname = pname_p % i
                    self.remove_trait(pname)
                    if pname in self.plugs:
                        # remove links to this plug
                        plug = self.plugs[pname]
                        to_del = []
                        for link in plug.links_from:
                            linkd = (link[2], link[1], self, pname)
                            to_del.append(linkd)
                        for linkd in to_del:
                            self.pipeline.remove_link(linkd)
                        del self.plugs[pname]
            for i in range(oval, val):
                pname = pname_p % i
                ptype2 = self._clone_trait(
                    ptype, {'output': False, 'optional': False})
                self.add_trait(pname, ptype2)
                plug = Plug(name=pname, optional=False, output=False)
                self.plugs[pname] = plug
                plug.on_trait_change(
                    self.pipeline.update_nodes_and_plugs_activation,
                    "enabled")
            if oval != val:
                ovalue = [getattr(self, pname_p % i) for i in range(val)]
                if isinstance(ptype, (traits.Str, traits.File,
                                      traits.Directory)):
                    # List trait doesn't accept Undefined as items
                    ovalue = [v  # if v not in (None, traits.Undefined)
                                 # else ''
                              for v in ovalue]
                setattr(self, self.output_names[in_index], ovalue)

        # setup new callback
        inputs = []
        for i, pname_p in enumerate(self.input_names):
            inputs += [pname_p % j for j in range(value[i])]
        self.on_trait_change(self.reduce_callback, inputs)

    def reduce_callback(self, obj, name, old_value, value):
        # find out which input pattern is used
        in_index = None
        for index, pname_p in enumerate(self.input_names):
            for i in range(self.lengths[index]):
                pname = pname_p % i
                if name == pname:
                    in_index = index
                    break
            if in_index is not None:
                break
        if in_index is None:
            in_indices = range(len(self.input_names))
        else:
            in_indices = [in_index]

        for in_index in in_indices:
            pname_p = self.input_names[in_index]
            output = self.output_names[in_index]
            value = [getattr(self, pname_p % i)
                     for i in range(self.lengths[in_index])]
            if self.skip_empty:
                value = [v for v in value
                         if v not in (None, traits.Undefined, '')]
            elif isinstance(self.input_types[in_index],
                            (traits.Str, traits.File, traits.Directory)):
                # List trait doesn't accept Undefined as items
                value = [v  # if v not in (None, traits.Undefined)
                            # else traits.Undefined
                         for v in value]
            setattr(self, output, value)

    def configured_controller(self):
        c = self.configure_controller()
        c.input_names = self.input_names
        c.output_names = self.output_names
        c.input_types = [(p.trait_type.__class__.__name__
                          if p.trait_type else p.__class__.__name__)
                         for p in self.input_types]
        return c

    @classmethod
    def configure_controller(cls):
        c = Controller()
        c.add_trait('input_types', traits.List(traits.Str))
        c.add_trait('input_names', traits.List(traits.Str))
        c.add_trait('output_names', traits.List(traits.Str))
        c.input_names = ['input_%d']
        c.output_names = ['outputs']
        c.input_types = ['File']
        return c

    @classmethod
    def build_node(cls, pipeline, name, conf_controller):
        t = []
        for ptype in conf_controller.input_types:
            if ptype == 'Str':
                t.append(traits.Str(traits.Undefined))
            elif ptype == 'File':
                t.append(traits.File(traits.Undefined))
            elif ptype not in (None, traits.Undefined, 'None', 'NoneType',
                               '<undefined>'):
                t.append(getattr(traits, ptype)())
        node = ReduceNode(pipeline, name, conf_controller.input_names,
                          conf_controller.output_names, input_types=t)
        return node

    def params_to_command(self):
        return ['custom_job']

    def build_job(self, name=None, referenced_input_files=[],
                  referenced_output_files=[], param_dict=None):
        from soma_workflow.custom_jobs import MapJob

        param_dict = dict(param_dict) if param_dict else {}
        param_dict['input_names'] = self.input_names
        param_dict['output_names'] = self.output_names
        param_dict['lengths'] = self.lengths
        for index, pname_p in enumerate(self.input_names):
            for i in range(self.lengths[index]):
                pname = pname_p % i
                param_dict[pname] = getattr(self, pname)
            output_name = self.output_names[index]
            value = getattr(self, output_name)
            if value not in (None, traits.Undefined):
                param_dict[output_name] = value
        job = MapJob(name=name,
                     referenced_input_files=referenced_input_files,
                     referenced_output_files=referenced_output_files,
                     param_dict=param_dict)
        return job
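The following is a minimal usage sketch, not part of the module source. It
assumes a bare :class:`~capsul.pipeline.pipeline.Pipeline` can be
instantiated directly to host the node, and the file paths are purely
hypothetical.

# Minimal usage sketch (assumption: an empty Pipeline can be instantiated
# directly in this capsul version; the paths below are hypothetical).
from capsul.api import Pipeline
from capsul.pipeline.custom_nodes.reduce_node import ReduceNode

pipeline = Pipeline()
node = ReduceNode(pipeline, 'reduce')   # defaults: 'input_%d' -> 'outputs'

# Growing 'lengths' creates the input plugs input_0 .. input_2.
node.lengths = [3]
node.input_0 = '/tmp/a.nii'
node.input_1 = '/tmp/b.nii'

# Each assignment triggers reduce_callback, which rebuilds the output list.
print(node.outputs)   # expected: ['/tmp/a.nii', '/tmp/b.nii', <undefined>]

# With skip_empty, undefined / empty values are dropped from the output.
node.skip_empty = True
print(node.outputs)   # expected: ['/tmp/a.nii', '/tmp/b.nii']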