Source code for capsul.pipeline.process_iteration

# -*- coding: utf-8 -*-
'''
Utility class for iterated nodes in a pipeline. This is mainly internal infrastructure, which a normal programmer should not have to bother about.
A pipeline programmer will not instantiate :class:`ProcessIteration` directly, but rather use the :class:`~capsul.pipeline.pipeline.Pipeline` method :meth:`~capsul.pipeline.pipeline.Pipeline.add_iterative_process`.

Classes
=======
:class:`ProcessIteration`
-------------------------
'''

from __future__ import print_function

from __future__ import absolute_import
import sys
import six
from traits.api import List, Undefined

from capsul.process.process import Process
from capsul.study_config.process_instance import get_process_instance
import capsul.study_config as study_cmod
from traits.api import File, Directory
from six.moves import range

if sys.version_info[0] >= 3:
    xrange = range

[docs] class ProcessIteration(Process): _doc_path = 'api/pipeline.html#processiteration' def __init__(self, process, iterative_parameters, study_config=None, context_name=None): super(ProcessIteration, self).__init__() if self.study_config is None and hasattr(Process, '_study_config'): study_config = study_cmod.default_study_config() if study_config is not None: self.study_config = study_config self.process = get_process_instance(process, study_config=study_config) if context_name is not None: self.process.context_name = context_name self.regular_parameters = set() self.iterative_parameters = set(iterative_parameters) # use the completion system (if any) to get induced (additional) # iterated parameters if study_config is not None: # don't import this at module level to avoid cyclic imports from capsul.attributes.completion_engine \ import ProcessCompletionEngine completion_engine \ = ProcessCompletionEngine.get_completion_engine(self) if hasattr(completion_engine, 'get_induced_iterative_parameters'): induced_iterative_parameters \ = completion_engine.get_induced_iterative_parameters() self.iterative_parameters.update(induced_iterative_parameters) iterative_parameters = self.iterative_parameters # Check that all iterative parameters are valid process parameters user_traits = self.process.user_traits() has_output = False inputs = [] for parameter in self.iterative_parameters: if parameter not in user_traits: raise ValueError('Cannot iterate on parameter %s ' 'that is not a parameter of process %s' % (parameter, self.process.id)) if user_traits[parameter].output: has_output = True else: inputs.append(parameter) # Create iterative process parameters by copying process parameter # and changing iterative parameters to list for name, trait in six.iteritems(user_traits): if name in iterative_parameters: kw = {} if trait.input_filename is False: kw['input_filename'] = False self.add_trait(name, List(trait, output=trait.output, optional=trait.optional, **kw)) if trait.groups: self.trait(name).groups = trait.groups if trait.forbid_completion is not None: # we don't have access to the pipeline or even the # node in self, we cannot propagate the forbid_completion # value outside of self. # However this will be done in Pipeline.add_process() when # inserting self in a pipeline, so this is OK. trait(name).forbid_completion \ = trait.forbid_completion else: self.regular_parameters.add(name) self.add_trait(name, trait) # copy initial value of the underlying process to self # Note: should be this be done via a links system ? setattr(self, name, getattr(self.process, name)) # if the process has iterative outputs, the output lists have to be # resized according to inputs if has_output: self.on_trait_change(self._resize_outputs, inputs) def _resize_outputs(self): num = 0 outputs = [] for param in self.iterative_parameters: if self.process.trait(param).output: if isinstance(self.process.trait(param).trait_type, (File, Directory)): outputs.append(param) else: num = max(num, len(getattr(self, param))) for param in outputs: value = getattr(self, param) mod = False if len(value) > num: new_value = value[:num] mod = True else: if len(value) < num: new_value = value \ + [self.process.trait(param).default] \ * (num - len(value)) mod = True if mod: setattr(self, param, new_value)
[docs] def change_iterative_plug(self, parameter, iterative=None): ''' Change a parameter to be iterative (or non-iterative) Parameters ---------- parameter: str parameter name iterative: bool or None if None, the iterative state will be toggled. If True or False, the parameter state will be set accordingly. ''' if parameter not in self.process.user_traits(): raise ValueError('Cannot iterate on parameter %s ' 'that is not a parameter of process %s' % (parameter, self.process.id)) is_iterative = parameter in self.iterative_parameters if is_iterative == iterative: return # nothing to be done if iterative is None: iterative = not is_iterative trait = self.process.trait(parameter) if iterative: # switch to iterative self.regular_parameters.remove(parameter) self.iterative_parameters.add(parameter) self.remove_trait(parameter) # Create iterative process parameter by copying process parameter # and changing iterative parameter to list self.add_trait(parameter, List(trait, output=trait.output, optional=trait.optional)) if trait.groups: self.trait(parameter).groups = trait.groups # if it is an output, the output list has to be # resized according to inputs if trait.output: inputs = [] for param in self.iterative_parameters: if not self.process.trait(param).output: inputs.append(param) self.on_trait_change(self._resize_outputs, inputs) else: # switch to non-iterative self.remove_trait(parameter) self.iterative_parameters.remove(parameter) self.regular_parameters.add(parameter) self.add_trait(parameter, trait) # copy initial value of the underlying process to self setattr(self, parameter, getattr(self.process, parameter))
def _run_process(self): # Check that all iterative parameter value have the same size no_output_value = None size = None size_error = False for parameter in self.iterative_parameters: trait = self.trait(parameter) value = getattr(self, parameter) psize = len(value) if psize and (not trait.output or len([x for x in value if x in ('', Undefined, None)]) == 0): if size is None: size = psize elif size != psize: size_error = True break if trait.output: if no_output_value is None: no_output_value = False elif no_output_value: size_error = True break else: if trait.output: if no_output_value is None: no_output_value = True elif not no_output_value: size_error = True break else: if size is None: size = psize elif size != psize: size_error = True break if size_error: raise ValueError('Iterative parameter values must be lists of the same size: %s' % ','.join('%s=%d' % (n, len(getattr(self,n))) for n in self.iterative_parameters)) if size == 0: return for parameter in self.regular_parameters: setattr(self.process, parameter, getattr(self, parameter)) if no_output_value: for parameter in self.iterative_parameters: trait = self.trait(parameter) if trait.output: setattr(self, parameter, []) outputs = {} for iteration in range(size): for parameter in self.iterative_parameters: #if not no_output_value or not self.trait(parameter).output: value = getattr(self, parameter) if len(value) > iteration: setattr(self.process, parameter, getattr(self, parameter)[iteration]) # operate completion self.complete_iteration(iteration) self.process() for parameter in self.iterative_parameters: trait = self.trait(parameter) if trait.output: outputs.setdefault(parameter,[]).append( getattr(self.process, parameter)) # reset empty value setattr(self.process, parameter, Undefined) for parameter, value in six.iteritems(outputs): setattr(self, parameter, value) else: for iteration in range(size): for parameter in self.iterative_parameters: setattr(self.process, parameter, getattr(self, parameter)[iteration]) # operate completion self.complete_iteration(iteration) self.process()
[docs] def set_study_config(self, study_config): super(ProcessIteration, self).set_study_config(study_config) self.process.set_study_config(study_config)
def complete_iteration(self, iteration): # don't import this at module level to avoid cyclic imports from capsul.attributes.completion_engine import ProcessCompletionEngine completion_engine = ProcessCompletionEngine.get_completion_engine( self) # check if it is an iterative completion engine if hasattr(completion_engine, 'complete_iteration_step'): completion_engine.complete_iteration_step(iteration)