Source code for nipype.interfaces.spm.preprocess

# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""SPM wrappers for preprocessing data"""

import os
from copy import deepcopy

# Third-party imports
import numpy as np

# Local imports
from ...utils.filemanip import (
    fname_presuffix,
    ensure_list,
    simplify_list,
    split_filename,
)
from ..base import (
    OutputMultiPath,
    TraitedSpec,
    isdefined,
    traits,
    Tuple,
    InputMultiPath,
    InputMultiObject,
    File,
    Str,
)
from .base import (
    SPMCommand,
    scans_for_fname,
    func_is_3d,
    scans_for_fnames,
    SPMCommandInputSpec,
    ImageFileSPM,
)

__docformat__ = "restructuredtext"


class FieldMapInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`FieldMap` interface.

    Most traits carry a dotted ``field`` metadata string (``subj.…``).
    NOTE(review): this is presumably the location of the value inside the
    SPM fieldmap job structure; the translation itself is performed by the
    SPM base classes, which are not part of this module.
    """

    # Only a single job type is accepted; the trait is flagged as deprecated.
    jobtype = traits.Enum(
        "calculatevdm",
        usedefault=True,
        deprecated="1.9.0",  # Two minor releases in the future
        desc="Must be 'calculatevdm'; to apply VDM, use the ApplyVDM interface.",
    )

    # --- Mandatory field map data -------------------------------------------
    phase_file = File(
        mandatory=True,
        exists=True,
        copyfile=False,
        field="subj.data.presubphasemag.phase",
        desc="presubstracted phase file",
    )
    magnitude_file = File(
        mandatory=True,
        exists=True,
        copyfile=False,
        field="subj.data.presubphasemag.magnitude",
        desc="presubstracted magnitude file",
    )

    # --- Acquisition defaults (subj.defaults.defaultsval.*) -----------------
    echo_times = Tuple(
        traits.Float,
        traits.Float,
        mandatory=True,
        field="subj.defaults.defaultsval.et",
        desc="short and long echo times",
    )
    maskbrain = traits.Bool(
        True,
        usedefault=True,
        field="subj.defaults.defaultsval.maskbrain",
        desc="masking or no masking of the brain",
    )
    blip_direction = traits.Enum(
        1,
        -1,
        mandatory=True,
        field="subj.defaults.defaultsval.blipdir",
        desc="polarity of the phase-encode blips",
    )
    total_readout_time = traits.Float(
        mandatory=True,
        field="subj.defaults.defaultsval.tert",
        desc="total EPI readout time",
    )
    epifm = traits.Bool(
        False,
        usedefault=True,
        field="subj.defaults.defaultsval.epifm",
        desc="epi-based field map",
    )
    jacobian_modulation = traits.Bool(
        False,
        usedefault=True,
        field="subj.defaults.defaultsval.ajm",
        desc="jacobian modulation",
    )
    # Unwarping defaults parameters
    method = traits.Enum(
        "Mark3D",
        "Mark2D",
        "Huttonish",
        usedefault=True,
        desc="One of: Mark3D, Mark2D, Huttonish",
        field="subj.defaults.defaultsval.uflags.method",
    )
    unwarp_fwhm = traits.Range(
        low=0,
        value=10,
        usedefault=True,
        field="subj.defaults.defaultsval.uflags.fwhm",
        desc="gaussian smoothing kernel width",
    )
    pad = traits.Range(
        low=0,
        value=0,
        usedefault=True,
        field="subj.defaults.defaultsval.uflags.pad",
        desc="padding kernel width",
    )
    ws = traits.Bool(
        True,
        usedefault=True,
        field="subj.defaults.defaultsval.uflags.ws",
        desc="weighted smoothing",
    )
    # Brain mask defaults parameters
    template = File(
        copyfile=False,
        exists=True,
        field="subj.defaults.defaultsval.mflags.template",
        desc="template image for brain masking",
    )
    mask_fwhm = traits.Range(
        low=0,
        value=5,
        usedefault=True,
        field="subj.defaults.defaultsval.mflags.fwhm",
        desc="gaussian smoothing kernel width",
    )
    nerode = traits.Range(
        low=0,
        value=2,
        usedefault=True,
        field="subj.defaults.defaultsval.mflags.nerode",
        desc="number of erosions",
    )
    # The description used to read "number of erosions" (copied from
    # ``nerode``); this trait feeds the ``ndilate`` field.
    ndilate = traits.Range(
        low=0,
        value=4,
        usedefault=True,
        field="subj.defaults.defaultsval.mflags.ndilate",
        desc="number of dilations",
    )
    thresh = traits.Float(
        0.5,
        usedefault=True,
        field="subj.defaults.defaultsval.mflags.thresh",
        desc="threshold used to create brain mask from segmented data",
    )
    reg = traits.Float(
        0.02,
        usedefault=True,
        field="subj.defaults.defaultsval.mflags.reg",
        desc="regularization value used in the segmentation",
    )
    # EPI unwarping for quality check
    epi_file = File(
        copyfile=False,
        exists=True,
        mandatory=True,
        field="subj.session.epi",
        desc="EPI to unwarp",
    )
    matchvdm = traits.Bool(
        True, usedefault=True, field="subj.matchvdm", desc="match VDM to EPI"
    )
    sessname = Str(
        "_run-", usedefault=True, field="subj.sessname", desc="VDM filename extension"
    )
    writeunwarped = traits.Bool(
        False, usedefault=True, field="subj.writeunwarped", desc="write unwarped EPI"
    )
    anat_file = File(
        copyfile=False,
        exists=True,
        field="subj.anat",
        desc="anatomical image for comparison",
    )
    matchanat = traits.Bool(
        True,
        usedefault=True,
        field="subj.matchanat",
        desc="match anatomical image to EPI",
    )


class FieldMapOutputSpec(TraitedSpec):
    """Output traits of the :class:`FieldMap` interface."""

    # VDM stands for "voxel displacement map", as in the FieldMap docstring;
    # the description previously said "difference".
    vdm = File(exists=True, desc="voxel displacement map")


class FieldMap(SPMCommand):
    """Use the fieldmap toolbox from spm to calculate the voxel displacement map (VDM).

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=173

    .. important::

        This interface does not deal with real/imag magnitude images nor
        with the two phase files case.

    Examples
    --------
    >>> from nipype.interfaces.spm import FieldMap
    >>> fm = FieldMap()
    >>> fm.inputs.phase_file = 'phase.nii'
    >>> fm.inputs.magnitude_file = 'magnitude.nii'
    >>> fm.inputs.echo_times = (5.19, 7.65)
    >>> fm.inputs.blip_direction = 1
    >>> fm.inputs.total_readout_time = 15.6
    >>> fm.inputs.epi_file = 'epi.nii'
    >>> fm.run() # doctest: +SKIP

    """

    input_spec = FieldMapInputSpec
    output_spec = FieldMapOutputSpec
    _jobtype = "tools"
    _jobname = "fieldmap"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm.

        The four image inputs are wrapped in a list and passed through
        ``scans_for_fname``; every other option is delegated to the parent
        class.
        """
        if opt in ("phase_file", "magnitude_file", "anat_file", "epi_file"):
            return scans_for_fname(ensure_list(val))

        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Nest the parsed inputs under the ``calculatevdm`` job key.

        Returns a one-element list whose single dict maps ``"calculatevdm"``
        to the first entry produced by the parent implementation.
        """
        einputs = super()._parse_inputs()
        return [{"calculatevdm": einputs[0]}]

    def _list_outputs(self):
        """Return the outputs dict with ``vdm`` derived from ``phase_file``.

        The expected VDM file name is the phase file name with the prefix
        ``vdm5_sc`` prepended (via ``fname_presuffix``).
        """
        outputs = self._outputs().get()
        outputs["vdm"] = fname_presuffix(self.inputs.phase_file, prefix="vdm5_sc")
        return outputs


class ApplyVDMInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`ApplyVDM` interface."""

    # --- Data (data.*) -------------------------------------------------------
    in_files = InputMultiObject(
        ImageFileSPM(exists=True),
        field="data.scans",
        mandatory=True,
        copyfile=True,
        desc="list of filenames to apply the vdm to",
    )
    vdmfile = File(
        field="data.vdmfile",
        desc="Voxel displacement map to use",
        mandatory=True,
        copyfile=True,
    )
    # --- Reslice options (roptions.*) ---------------------------------------
    distortion_direction = traits.Int(
        2,
        field="roptions.pedir",
        desc="phase encode direction input data have been acquired with",
        usedefault=True,
    )
    # Two-element list; ApplyVDM._list_outputs treats element 0 > 0 as
    # "reslice all images" and element 1 > 0 as "write a mean image".
    write_which = traits.List(
        traits.Int,
        [2, 1],
        field="roptions.which",
        minlen=2,
        maxlen=2,
        usedefault=True,
        desc="If the first value is non-zero, reslice all images. If the second value is non-zero, reslice a mean image.",
    )
    interpolation = traits.Range(
        value=4,
        low=0,
        high=7,
        field="roptions.rinterp",
        desc="degree of b-spline used for interpolation",
    )
    write_wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="roptions.wrap",
        desc=("Check if interpolation should wrap in [x,y,z]"),
    )
    write_mask = traits.Bool(
        field="roptions.mask", desc="True/False mask time series images"
    )
    # Prefix used by ApplyVDM._list_outputs to name the corrected files.
    out_prefix = traits.String(
        "u",
        field="roptions.prefix",
        usedefault=True,
        desc="fieldmap corrected output prefix",
    )


class ApplyVDMOutputSpec(TraitedSpec):
    """Output traits of the :class:`ApplyVDM` interface."""

    # Each entry is either a single file or a list of files, mirroring the
    # (possibly nested) structure of the ``in_files`` input.
    out_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc=("These will be the fieldmap corrected files."),
    )
    mean_image = File(exists=True, desc="Mean image")


class ApplyVDM(SPMCommand):
    """Use the fieldmap toolbox from spm to apply the voxel displacement map (VDM) to some epi files.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=173

    .. important::

        This interface does not deal with real/imag magnitude images nor
        with the two phase files case.

    """

    input_spec = ApplyVDMInputSpec
    output_spec = ApplyVDMOutputSpec
    _jobtype = "tools"
    _jobname = "fieldmap"

    def _format_arg(self, opt, spec, val):
        """Translate a single input value into the form handed to SPM.

        ``in_files`` and ``vdmfile`` are converted with the scan helpers;
        anything else falls through to the parent implementation.
        """
        if opt == "in_files":
            files = ensure_list(val)
            return scans_for_fnames(files, keep4d=False, separate_sessions=False)
        if opt == "vdmfile":
            return scans_for_fname(ensure_list(val))
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Nest the parsed inputs under the ``applyvdm`` job key."""
        parsed = super()._parse_inputs()
        return [{"applyvdm": parsed[0]}]

    def _list_outputs(self):
        """Predict output file names from ``in_files`` and ``write_which``.

        * ``write_which[1] > 0``: ``mean_image`` is the first input image
          prefixed with ``meanu``.
        * ``write_which[0] > 0``: ``out_files`` mirrors ``in_files`` with
          ``out_prefix`` prepended to every file name.
        """
        outputs = self._outputs().get()
        which = self.inputs.write_which
        want_all, want_mean = which[0] > 0, which[1] > 0

        if want_mean:
            head = self.inputs.in_files[0]
            # A nested list means the first entry is itself a list of files.
            reference = head[0] if isinstance(head, list) else head
            outputs["mean_image"] = fname_presuffix(reference, prefix="meanu")

        if want_all:
            corrected = []
            for entry in ensure_list(self.inputs.in_files):
                if isinstance(entry, list):
                    renamed = [
                        fname_presuffix(member, prefix=self.inputs.out_prefix)
                        for member in ensure_list(entry)
                    ]
                else:
                    renamed = fname_presuffix(entry, prefix=self.inputs.out_prefix)
                corrected.append(renamed)
            outputs["out_files"] = corrected
        return outputs


class SliceTimingInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`SliceTiming` interface."""

    # Each entry may be a single image or a list of images.
    in_files = InputMultiPath(
        traits.Either(
            traits.List(ImageFileSPM(exists=True)), ImageFileSPM(exists=True)
        ),
        field="scans",
        desc="list of filenames to apply slice timing",
        mandatory=True,
        copyfile=False,
    )
    num_slices = traits.Int(
        field="nslices", desc="number of slices in a volume", mandatory=True
    )
    time_repetition = traits.Float(
        field="tr",
        desc=("time between volume acquisitions (start to start time)"),
        mandatory=True,
    )
    time_acquisition = traits.Float(
        field="ta",
        desc=("time of volume acquisition. usually calculated as TR-(TR/num_slices)"),
        mandatory=True,
    )
    # Elements may be ints or floats (see desc: slice indices or onsets in ms).
    slice_order = traits.List(
        traits.Either(traits.Int(), traits.Float()),
        field="so",
        desc=("1-based order or onset (in ms) in which slices are acquired"),
        mandatory=True,
    )
    ref_slice = traits.Either(
        traits.Int(),
        traits.Float(),
        field="refslice",
        desc="1-based Number of the reference slice or "
        "reference time point if slice_order is in "
        "onsets (ms)",
        mandatory=True,
    )
    # Prefix used by SliceTiming._list_outputs to name the corrected files.
    out_prefix = traits.String(
        "a", field="prefix", usedefault=True, desc="slicetimed output prefix"
    )


class SliceTimingOutputSpec(TraitedSpec):
    """Output traits of the :class:`SliceTiming` interface."""

    # Each entry is a single file or a list of files, mirroring ``in_files``.
    timecorrected_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc="slice time corrected files",
    )


class SliceTiming(SPMCommand):
    """Use spm to perform slice timing correction.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=19

    Examples
    --------

    >>> from nipype.interfaces.spm import SliceTiming
    >>> st = SliceTiming()
    >>> st.inputs.in_files = 'functional.nii'
    >>> st.inputs.num_slices = 32
    >>> st.inputs.time_repetition = 6.0
    >>> st.inputs.time_acquisition = 6. - 6./32.
    >>> st.inputs.slice_order = list(range(32,0,-1))
    >>> st.inputs.ref_slice = 1
    >>> st.run() # doctest: +SKIP

    """

    input_spec = SliceTimingInputSpec
    output_spec = SliceTimingOutputSpec

    _jobtype = "temporal"
    _jobname = "st"

    def _format_arg(self, opt, spec, val):
        """Translate a single input value into the form handed to SPM.

        Only ``in_files`` is handled here (converted with
        ``scans_for_fnames``, one group per session); every other option is
        delegated to the parent class.
        """
        if opt != "in_files":
            return super()._format_arg(opt, spec, val)
        files = ensure_list(val)
        return scans_for_fnames(files, keep4d=False, separate_sessions=True)

    def _list_outputs(self):
        """Predict corrected file names: ``out_prefix`` + each input name.

        The nesting of ``in_files`` (plain file vs. list of files per entry)
        is preserved in ``timecorrected_files``.
        """
        outputs = self._outputs().get()
        corrected = []
        for entry in ensure_list(self.inputs.in_files):
            if isinstance(entry, list):
                renamed = [
                    fname_presuffix(member, prefix=self.inputs.out_prefix)
                    for member in entry
                ]
            else:
                renamed = fname_presuffix(entry, prefix=self.inputs.out_prefix)
            corrected.append(renamed)
        outputs["timecorrected_files"] = corrected
        return outputs


class RealignInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`Realign` interface.

    NOTE(review): ``eoptions.*`` fields presumably hold estimation options
    and ``roptions.*`` fields reslice (write) options -- confirm against the
    SPM batch definition.
    """

    # Each entry may be a single image or a list of images.
    in_files = InputMultiPath(
        traits.Either(
            ImageFileSPM(exists=True), traits.List(ImageFileSPM(exists=True))
        ),
        field="data",
        mandatory=True,
        copyfile=True,
        desc="list of filenames to realign",
    )
    # Realign uses this value as the key of the job dict it builds and to
    # decide which outputs to report.
    jobtype = traits.Enum(
        "estwrite",
        "estimate",
        "write",
        desc="one of: estimate, write, estwrite",
        usedefault=True,
    )
    quality = traits.Range(
        low=0.0, high=1.0, field="eoptions.quality", desc="0.1 = fast, 1.0 = precise"
    )
    fwhm = traits.Range(
        low=0.0, field="eoptions.fwhm", desc="gaussian smoothing kernel width"
    )
    separation = traits.Range(
        low=0.0, field="eoptions.sep", desc="sampling separation in mm"
    )
    register_to_mean = traits.Bool(
        field="eoptions.rtm",
        desc=("Indicate whether realignment is done to the mean image"),
    )
    weight_img = File(
        exists=True, field="eoptions.weight", desc="filename of weighting image"
    )
    interp = traits.Range(
        low=0,
        high=7,
        field="eoptions.interp",
        desc="degree of b-spline used for interpolation",
    )
    wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="eoptions.wrap",
        desc="Check if interpolation should wrap in [x,y,z]",
    )
    # Two-element list; Realign._list_outputs treats element 0 > 0 as
    # "resliced images are written" and element 1 > 0 as "mean image written".
    write_which = traits.List(
        traits.Int,
        [2, 1],
        field="roptions.which",
        minlen=2,
        maxlen=2,
        usedefault=True,
        desc="determines which images to reslice",
    )
    write_interp = traits.Range(
        low=0,
        high=7,
        field="roptions.interp",
        desc=("degree of b-spline used for interpolation"),
    )
    write_wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="roptions.wrap",
        desc=("Check if interpolation should wrap in [x,y,z]"),
    )
    write_mask = traits.Bool(field="roptions.mask", desc="True/False mask output image")
    # Prefix used by Realign._list_outputs to name the resliced files.
    out_prefix = traits.String(
        "r", field="roptions.prefix", usedefault=True, desc="realigned output prefix"
    )


class RealignOutputSpec(TraitedSpec):
    """Output traits of the :class:`Realign` interface."""

    mean_image = File(exists=True, desc="Mean image file from the realignment")
    # Set by Realign._list_outputs for the "estimate" and "estwrite" job types.
    modified_in_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc=(
            "Copies of all files passed to "
            "in_files. Headers will have "
            "been modified to align all "
            "images with the first, or "
            "optionally to first do that, "
            "extract a mean image, and "
            "re-align to that mean image."
        ),
    )
    realigned_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc=(
            "If jobtype is write or estwrite, "
            "these will be the resliced files."
            " Otherwise, they will be copies "
            "of in_files that have had their "
            "headers rewritten."
        ),
    )
    # File names are predicted as "rp_<input name>.txt" by Realign._list_outputs.
    realignment_parameters = OutputMultiPath(
        File(exists=True), desc=("Estimated translation and rotation parameters")
    )


class Realign(SPMCommand):
    """Use spm_realign for estimating within modality rigid body alignment

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=25

    Examples
    --------

    >>> import nipype.interfaces.spm as spm
    >>> realign = spm.Realign()
    >>> realign.inputs.in_files = 'functional.nii'
    >>> realign.inputs.register_to_mean = True
    >>> realign.run() # doctest: +SKIP

    """

    input_spec = RealignInputSpec
    output_spec = RealignOutputSpec

    _jobtype = "spatial"
    _jobname = "realign"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm.

        ``in_files`` is converted with ``scans_for_fnames``; sessions are
        kept separate unless the job type is ``"write"``. All other options
        are delegated to the parent class.
        """
        if opt == "in_files":
            separate_sessions = self.inputs.jobtype != "write"
            return scans_for_fnames(
                val, keep4d=False, separate_sessions=separate_sessions
            )
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Nest the parsed inputs under a key named after ``jobtype``."""
        einputs = super()._parse_inputs()
        return [{"%s" % (self.inputs.jobtype): einputs[0]}]

    def _list_outputs(self):
        """Predict output file names for the selected job type.

        * not ``write``: ``realignment_parameters`` gets one
          ``rp_<name>.txt`` entry per item of ``in_files``.
        * ``estimate``: ``realigned_files`` is ``in_files`` itself.
        * ``estimate`` / ``estwrite``: ``modified_in_files`` is ``in_files``.
        * ``write`` / ``estwrite``: ``mean_image`` (prefix ``mean``) and
          ``realigned_files`` (prefix ``out_prefix``) as selected by
          ``write_which``.
        """
        outputs = self._outputs().get()
        jobtype = self.inputs.jobtype
        resliced_all = self.inputs.write_which[0] > 0
        resliced_mean = self.inputs.write_which[1] > 0

        # The loop now lives inside the ``isdefined`` guard: previously only
        # the list initialisation was guarded, so an undefined ``in_files``
        # still reached the ``for`` statement.
        if jobtype != "write" and isdefined(self.inputs.in_files):
            outputs["realignment_parameters"] = []
            for imgf in self.inputs.in_files:
                tmp_imgf = imgf[0] if isinstance(imgf, list) else imgf
                outputs["realignment_parameters"].append(
                    fname_presuffix(
                        tmp_imgf, prefix="rp_", suffix=".txt", use_ext=False
                    )
                )
                # Stop after the first entry when it is a single 3D file.
                # NOTE(review): presumably a flat list of 3D volumes forms one
                # session with a single parameter file -- confirm.
                if not isinstance(imgf, list) and func_is_3d(imgf):
                    break

        if jobtype == "estimate":
            outputs["realigned_files"] = self.inputs.in_files
        if jobtype in ("estimate", "estwrite"):
            outputs["modified_in_files"] = self.inputs.in_files
        if jobtype in ("write", "estwrite"):
            if isinstance(self.inputs.in_files[0], list):
                first_image = self.inputs.in_files[0][0]
            else:
                first_image = self.inputs.in_files[0]

            if resliced_mean:
                outputs["mean_image"] = fname_presuffix(first_image, prefix="mean")

            if resliced_all:
                outputs["realigned_files"] = []
                for imgf in ensure_list(self.inputs.in_files):
                    if isinstance(imgf, list):
                        realigned_run = [
                            fname_presuffix(inner_imgf, prefix=self.inputs.out_prefix)
                            for inner_imgf in ensure_list(imgf)
                        ]
                    else:
                        realigned_run = fname_presuffix(
                            imgf, prefix=self.inputs.out_prefix
                        )
                    outputs["realigned_files"].append(realigned_run)
        return outputs


class RealignUnwarpInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`RealignUnwarp` interface.

    NOTE(review): the ``field`` prefixes presumably group the options as
    estimation (``eoptions``), unwarp estimation (``uweoptions``) and
    unwarp reslicing (``uwroptions``) -- confirm against the SPM batch
    definition.
    """

    # --- Data (data.*) -------------------------------------------------------
    in_files = InputMultiObject(
        traits.Either(
            ImageFileSPM(exists=True), traits.List(ImageFileSPM(exists=True))
        ),
        field="data.scans",
        mandatory=True,
        copyfile=True,
        desc="list of filenames to realign and unwarp",
    )
    # RealignUnwarp._parse_inputs attaches this one map to every session.
    phase_map = File(
        field="data.pmscan",
        desc="Voxel displacement map to use in unwarping. Unlike SPM standard "
        "behaviour, the same map will be used for all sessions",
        copyfile=False,
    )
    # --- eoptions.* ----------------------------------------------------------
    quality = traits.Range(
        low=0.0, high=1.0, field="eoptions.quality", desc="0.1 = fast, 1.0 = precise"
    )
    fwhm = traits.Range(
        low=0.0, field="eoptions.fwhm", desc="gaussian smoothing kernel width"
    )
    separation = traits.Range(
        low=0.0, field="eoptions.sep", desc="sampling separation in mm"
    )
    register_to_mean = traits.Bool(
        field="eoptions.rtm",
        desc="Indicate whether realignment is done to the mean image",
    )
    weight_img = File(
        exists=True, field="eoptions.weight", desc="filename of weighting image"
    )
    interp = traits.Range(
        low=0,
        high=7,
        field="eoptions.einterp",
        desc="degree of b-spline used for interpolation",
    )
    wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="eoptions.ewrap",
        desc="Check if interpolation should wrap in [x,y,z]",
    )
    # --- uweoptions.* --------------------------------------------------------
    est_basis_func = traits.List(
        traits.Int(),
        minlen=2,
        maxlen=2,
        field="uweoptions.basfcn",
        desc="Number of basis functions to use for each dimension",
    )
    est_reg_order = traits.Range(
        low=0,
        high=3,
        field="uweoptions.regorder",
        desc=(
            "This parameter determines how to balance the compromise between likelihood "
            "maximization and smoothness maximization of the estimated field."
        ),
    )
    # Declared as a one-element list rather than a scalar.
    est_reg_factor = traits.List(
        traits.Int,
        [100000],
        field="uweoptions.lambda",
        minlen=1,
        maxlen=1,
        usedefault=True,
        desc="Regularisation factor. Default: 100000 (medium).",
    )
    est_jacobian_deformations = traits.Bool(
        field="uweoptions.jm",
        desc=(
            "Jacobian deformations. In theory a good idea to include them, "
            " in practice a bad idea. Default: No."
        ),
    )
    est_first_order_effects = traits.List(
        traits.Int(),
        minlen=1,
        maxlen=6,
        field="uweoptions.fot",
        desc="First order effects should only depend on pitch and roll, i.e. [4 5]",
    )
    est_second_order_effects = traits.List(
        traits.Int(),
        minlen=1,
        maxlen=6,
        field="uweoptions.sot",
        desc="List of second order terms to model second derivatives of.",
    )
    est_unwarp_fwhm = traits.Range(
        low=0.0,
        field="uweoptions.uwfwhm",
        desc="gaussian smoothing kernel width for unwarp",
    )
    est_re_est_mov_par = traits.Bool(
        field="uweoptions.rem",
        desc="Re-estimate movement parameters at each unwarping iteration.",
    )
    # Declared as a one-element list rather than a scalar.
    est_num_of_iterations = traits.List(
        traits.Int,
        [5],
        field="uweoptions.noi",
        minlen=1,
        maxlen=1,
        usedefault=True,
        desc="Number of iterations.",
    )
    est_taylor_expansion_point = traits.String(
        "Average",
        field="uweoptions.expround",
        usedefault=True,
        desc="Point in position space to perform Taylor-expansion around.",
    )
    # --- uwroptions.* --------------------------------------------------------
    # Two-element list; RealignUnwarp._list_outputs treats element 0 > 0 as
    # "resliced images are written" and element 1 > 0 as "mean image written".
    reslice_which = traits.List(
        traits.Int,
        [2, 1],
        field="uwroptions.uwwhich",
        minlen=2,
        maxlen=2,
        usedefault=True,
        desc="determines which images to reslice",
    )
    reslice_interp = traits.Range(
        low=0,
        high=7,
        field="uwroptions.rinterp",
        desc="degree of b-spline used for interpolation",
    )
    reslice_wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="uwroptions.wrap",
        desc="Check if interpolation should wrap in [x,y,z]",
    )
    reslice_mask = traits.Bool(
        field="uwroptions.mask", desc="True/False mask output image"
    )
    # Prefix used by RealignUnwarp._list_outputs to name the written files.
    out_prefix = traits.String(
        "u",
        field="uwroptions.prefix",
        usedefault=True,
        desc="realigned and unwarped output prefix",
    )


class RealignUnwarpOutputSpec(TraitedSpec):
    """Output traits of the :class:`RealignUnwarp` interface."""

    mean_image = File(
        exists=True, desc="Mean image file from the realignment & unwarping"
    )
    # NOTE(review): RealignUnwarp._list_outputs, as visible in this module,
    # never assigns this output -- confirm whether that is intentional.
    modified_in_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc=(
            "Copies of all files passed to "
            "in_files. Headers will have "
            "been modified to align all "
            "images with the first, or "
            "optionally to first do that, "
            "extract a mean image, and "
            "re-align to that mean image."
        ),
    )
    # Each entry is a single file or a list of files, mirroring ``in_files``.
    realigned_unwarped_files = OutputMultiPath(
        traits.Either(traits.List(File(exists=True)), File(exists=True)),
        desc="Realigned and unwarped files written to disc.",
    )
    realignment_parameters = OutputMultiPath(
        File(exists=True), desc="Estimated translation and rotation parameters"
    )


class RealignUnwarp(SPMCommand):
    """Use spm_uw_estimate for estimating within subject registration and unwarping
    of time series. Function accepts only one single field map. If in_files is a
    list of files they will be treated as separate sessions but associated to the
    same fieldmap.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=31

    Examples
    --------

    >>> import nipype.interfaces.spm as spm
    >>> realignUnwarp = spm.RealignUnwarp()
    >>> realignUnwarp.inputs.in_files = ['functional.nii', 'functional2.nii']
    >>> realignUnwarp.inputs.phase_map = 'voxeldisplacemap.vdm'
    >>> realignUnwarp.inputs.register_to_mean = True
    >>> realignUnwarp.run() # doctest: +SKIP

    """

    input_spec = RealignUnwarpInputSpec
    output_spec = RealignUnwarpOutputSpec

    _jobtype = "spatial"
    _jobname = "realignunwarp"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm.

        ``in_files`` is converted with ``scans_for_fnames`` (one group per
        session); all other options are delegated to the parent class.
        """
        if opt == "in_files":
            return scans_for_fnames(
                ensure_list(val), keep4d=False, separate_sessions=True
            )
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self, skip=()):
        """Build the job dict, pairing every session with the same phase map.

        Parameters
        ----------
        skip : sequence of str, optional
            Input names to leave out; forwarded to the parent implementation
            (it used to be silently replaced by an empty tuple).

        Returns
        -------
        list
            One-element list holding the job dict, whose ``"data"`` entry is
            a list of ``{"scans": ..., "pmscan": ...}`` dicts.
        """
        spmdict = super()._parse_inputs(skip=skip)[0]

        # An empty string stands in for the phase map when none is given.
        if isdefined(self.inputs.phase_map):
            pmscan = spmdict["data"]["pmscan"]
        else:
            pmscan = ""

        # ``data`` is only rebuilt when ``in_files`` is defined; previously
        # the assignment below the branch could hit an unbound local.
        if isdefined(self.inputs.in_files):
            if isinstance(self.inputs.in_files, list):
                data = [
                    dict(scans=sess, pmscan=pmscan) for sess in spmdict["data"]["scans"]
                ]
            else:
                data = [dict(scans=spmdict["data"]["scans"], pmscan=pmscan)]
            spmdict["data"] = data

        return [spmdict]

    def _list_outputs(self):
        """Predict output file names from ``in_files`` and ``reslice_which``.

        * ``realignment_parameters``: one ``rp_<name>.txt`` per item of
          ``in_files``.
        * ``reslice_which[1] > 0``: ``mean_image`` is the first input image
          prefixed with ``meanu``.
        * ``reslice_which[0] > 0``: ``realigned_unwarped_files`` mirrors
          ``in_files`` with ``out_prefix`` prepended to every file name.
        """
        outputs = self._outputs().get()
        resliced_all = self.inputs.reslice_which[0] > 0
        resliced_mean = self.inputs.reslice_which[1] > 0

        # The loop now lives inside the ``isdefined`` guard: previously only
        # the list initialisation was guarded.
        if isdefined(self.inputs.in_files):
            outputs["realignment_parameters"] = []
            for imgf in self.inputs.in_files:
                tmp_imgf = imgf[0] if isinstance(imgf, list) else imgf
                outputs["realignment_parameters"].append(
                    fname_presuffix(
                        tmp_imgf, prefix="rp_", suffix=".txt", use_ext=False
                    )
                )
                # Stop after the first entry when it is a single 3D file.
                # NOTE(review): presumably a flat list of 3D volumes forms one
                # session with a single parameter file -- confirm.
                if not isinstance(imgf, list) and func_is_3d(imgf):
                    break

        if isinstance(self.inputs.in_files[0], list):
            first_image = self.inputs.in_files[0][0]
        else:
            first_image = self.inputs.in_files[0]

        if resliced_mean:
            outputs["mean_image"] = fname_presuffix(first_image, prefix="meanu")

        if resliced_all:
            outputs["realigned_unwarped_files"] = []
            for imgf in ensure_list(self.inputs.in_files):
                if isinstance(imgf, list):
                    realigned_run = [
                        fname_presuffix(inner_imgf, prefix=self.inputs.out_prefix)
                        for inner_imgf in ensure_list(imgf)
                    ]
                else:
                    realigned_run = fname_presuffix(imgf, prefix=self.inputs.out_prefix)
                outputs["realigned_unwarped_files"].append(realigned_run)
        return outputs


class CoregisterInputSpec(SPMCommandInputSpec):
    """Input traits of the :class:`Coregister` interface.

    NOTE(review): ``eoptions.*`` fields presumably hold estimation options
    and ``roptions.*`` fields reslice (write) options -- confirm against the
    SPM batch definition.
    """

    target = ImageFileSPM(
        exists=True,
        mandatory=True,
        field="ref",
        desc="reference file to register to",
        copyfile=False,
    )
    source = InputMultiPath(
        ImageFileSPM(exists=True),
        field="source",
        desc="file to register to target",
        copyfile=True,
        mandatory=True,
    )
    # Coregister uses this value as the key of the job dict it builds and to
    # decide how ``source`` / ``apply_to_files`` are formatted.
    jobtype = traits.Enum(
        "estwrite",
        "estimate",
        "write",
        desc="one of: estimate, write, estwrite",
        usedefault=True,
    )
    apply_to_files = InputMultiPath(
        File(exists=True),
        field="other",
        desc="files to apply transformation to",
        copyfile=True,
    )
    cost_function = traits.Enum(
        "mi",
        "nmi",
        "ecc",
        "ncc",
        field="eoptions.cost_fun",
        desc="""cost function, one of:
                    'mi' - Mutual Information,
                    'nmi' - Normalised Mutual Information,
                    'ecc' - Entropy Correlation Coefficient,
                    'ncc' - Normalised Cross Correlation""",
    )
    fwhm = traits.List(
        traits.Float(),
        minlen=2,
        maxlen=2,
        field="eoptions.fwhm",
        desc="gaussian smoothing kernel width (mm)",
    )
    # No length constraints are declared on the next two lists.
    separation = traits.List(
        traits.Float(), field="eoptions.sep", desc="sampling separation in mm"
    )
    tolerance = traits.List(
        traits.Float(),
        field="eoptions.tol",
        desc="acceptable tolerance for each of 12 params",
    )
    write_interp = traits.Range(
        low=0,
        high=7,
        field="roptions.interp",
        desc=("degree of b-spline used for interpolation"),
    )
    write_wrap = traits.List(
        traits.Int(),
        minlen=3,
        maxlen=3,
        field="roptions.wrap",
        desc=("Check if interpolation should wrap in [x,y,z]"),
    )
    write_mask = traits.Bool(field="roptions.mask", desc="True/False mask output image")
    out_prefix = traits.String(
        "r", field="roptions.prefix", usedefault=True, desc="coregistered output prefix"
    )


class CoregisterOutputSpec(TraitedSpec):
    """Output traits of the :class:`Coregister` interface."""

    # Counterpart of the ``source`` input.
    coregistered_source = OutputMultiPath(
        File(exists=True), desc="Coregistered source files"
    )
    # Counterpart of the ``apply_to_files`` input.
    coregistered_files = OutputMultiPath(
        File(exists=True), desc="Coregistered other files"
    )


class Coregister(SPMCommand):
    """Use spm_coreg for estimating cross-modality rigid body alignment

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=39

    Examples
    --------

    >>> import nipype.interfaces.spm as spm
    >>> coreg = spm.Coregister()
    >>> coreg.inputs.target = 'functional.nii'
    >>> coreg.inputs.source = 'structural.nii'
    >>> coreg.run() # doctest: +SKIP

    """

    input_spec = CoregisterInputSpec
    output_spec = CoregisterOutputSpec
    _jobtype = "spatial"
    _jobname = "coreg"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm

        ``target`` is always passed as a scan list (4D files kept whole), and
        so is ``source`` unless the job type is "write".  In "write" mode
        ``apply_to_files`` is skipped by ``_parse_inputs``, so those files are
        appended to ``source`` here instead.  Every other option is delegated
        to the parent class.
        """
        write_only = self.inputs.jobtype == "write"
        if opt == "target" or (opt == "source" and not write_only):
            return scans_for_fnames(ensure_list(val), keep4d=True)
        if opt == "apply_to_files":
            return np.array(ensure_list(val), dtype=object)
        if opt == "source" and write_only:
            if isdefined(self.inputs.apply_to_files):
                return scans_for_fnames(val + self.inputs.apply_to_files)
            return scans_for_fnames(val)
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Build the SPM job description, keyed by the selected job type.

        ``jobtype`` never becomes a field of its own; it is the key wrapping
        the parsed inputs.  In "write" mode ``apply_to_files`` is skipped as
        well because ``_format_arg`` folds it into ``source``.
        """
        jobtype = self.inputs.jobtype
        if jobtype == "write":
            skip = ("jobtype", "apply_to_files")
        else:
            # Must be a real 1-tuple: a bare ("jobtype") is just a string and
            # would turn any ``name in skip`` check into a substring test.
            skip = ("jobtype",)
        einputs = super()._parse_inputs(skip=skip)
        return [{str(jobtype): einputs[0]}]

    def _list_outputs(self):
        """Predict output paths from the inputs and the job type.

        "estimate" reports the input paths unchanged; "write" and "estwrite"
        report the same paths with ``out_prefix`` prepended to each filename.
        """
        outputs = self._outputs().get()
        jobtype = self.inputs.jobtype
        have_others = isdefined(self.inputs.apply_to_files)

        if jobtype == "estimate":
            if have_others:
                outputs["coregistered_files"] = self.inputs.apply_to_files
            outputs["coregistered_source"] = self.inputs.source
        elif jobtype in ("write", "estwrite"):
            prefix = self.inputs.out_prefix
            if have_others:
                outputs["coregistered_files"] = [
                    fname_presuffix(imgf, prefix=prefix)
                    for imgf in ensure_list(self.inputs.apply_to_files)
                ]
            outputs["coregistered_source"] = [
                fname_presuffix(imgf, prefix=prefix)
                for imgf in ensure_list(self.inputs.source)
            ]

        return outputs


class NormalizeInputSpec(SPMCommandInputSpec):
    """Input traits for the ``Normalize`` interface.

    Each trait carries a ``field`` metadata string (``subj.*``, ``eoptions.*``
    or ``roptions.*``); presumably this is the path of the matching entry in
    the SPM job structure, but the mapping itself is done by the base class
    and is not visible here.
    """

    # ``template`` and ``source`` are each declared mutually exclusive (xor)
    # with ``parameter_file``; all three are flagged mandatory.
    template = File(
        exists=True,
        field="eoptions.template",
        desc="template file to normalize to",
        mandatory=True,
        xor=["parameter_file"],
        copyfile=False,
    )
    source = InputMultiPath(
        ImageFileSPM(exists=True),
        field="subj.source",
        xor=["parameter_file"],
        desc="file to normalize to template",
        mandatory=True,
        copyfile=True,
    )
    # Defaults to "estwrite" (first value, usedefault=True). Note the
    # estimate-only value is spelled "est", not "estimate".
    jobtype = traits.Enum(
        "estwrite", "est", "write", usedefault=True, desc="Estimate, Write or do both"
    )
    # Each entry may be a single file or a list of files.
    apply_to_files = InputMultiPath(
        traits.Either(File(exists=True), traits.List(File(exists=True))),
        field="subj.resample",
        desc="files to apply transformation to",
        copyfile=True,
    )
    parameter_file = File(
        field="subj.matname",
        mandatory=True,
        xor=["source", "template"],
        desc="normalization parameter file*_sn.mat",
        copyfile=False,
    )
    source_weight = File(
        field="subj.wtsrc", desc="name of weighting image for source", copyfile=False
    )
    template_weight = File(
        field="eoptions.weight",
        desc="name of weighting image for template",
        copyfile=False,
    )
    source_image_smoothing = traits.Float(
        field="eoptions.smosrc", desc="source smoothing"
    )
    template_image_smoothing = traits.Float(
        field="eoptions.smoref", desc="template smoothing"
    )
    affine_regularization_type = traits.Enum(
        "mni", "size", "none", field="eoptions.regtype", desc="mni, size, none"
    )
    DCT_period_cutoff = traits.Float(
        field="eoptions.cutoff", desc="Cutoff of for DCT bases"
    )
    nonlinear_iterations = traits.Int(
        field="eoptions.nits", desc=("Number of iterations of nonlinear warping")
    )
    nonlinear_regularization = traits.Float(
        field="eoptions.reg",
        desc=(
            "the amount of the "
            "regularization for the "
            "nonlinear part of the "
            "normalization"
        ),
    )
    # When set and True, ``Normalize._list_outputs`` prepends "m" to
    # ``out_prefix`` for the predicted output filenames.
    write_preserve = traits.Bool(
        field="roptions.preserve", desc="True/False warped images are modulated"
    )
    # Validated as an outer list of exactly 2 inner lists of 3 floats each.
    write_bounding_box = traits.List(
        traits.List(traits.Float(), minlen=3, maxlen=3),
        field="roptions.bb",
        minlen=2,
        maxlen=2,
        desc="3x2-element list of lists",
    )
    write_voxel_sizes = traits.List(
        traits.Float(), field="roptions.vox", minlen=3, maxlen=3, desc="3-element list"
    )
    write_interp = traits.Range(
        low=0,
        high=7,
        field="roptions.interp",
        desc=("degree of b-spline used for interpolation"),
    )
    # No length limit on the trait itself; ``Normalize._format_arg`` raises
    # ValueError unless exactly 3 elements are given.
    write_wrap = traits.List(
        traits.Int(),
        field="roptions.wrap",
        desc=("Check if interpolation should wrap in [x,y,z] - list of bools"),
    )
    out_prefix = traits.String(
        "w", field="roptions.prefix", usedefault=True, desc="normalized output prefix"
    )


class NormalizeOutputSpec(TraitedSpec):
    """Output traits reported by the ``Normalize`` interface."""

    # One "<source>_sn.mat" path per source file; only set for job types
    # starting with "est" (see ``Normalize._list_outputs``).
    normalization_parameters = OutputMultiPath(
        File(exists=True), desc=("MAT files containing the normalization parameters")
    )
    # Prefixed copies of ``source`` for job types containing "write".
    normalized_source = OutputMultiPath(
        File(exists=True), desc="Normalized source files"
    )
    # Prefixed copies of ``apply_to_files``, flattened into a single list.
    normalized_files = OutputMultiPath(File(exists=True), desc="Normalized other files")


class Normalize(SPMCommand):
    """use spm_normalise for warping an image to a template

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=203

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> norm = spm.Normalize()
    >>> norm.inputs.source = 'functional.nii'
    >>> norm.run() # doctest: +SKIP

    """

    input_spec = NormalizeInputSpec
    output_spec = NormalizeOutputSpec
    _jobtype = "spatial"
    _jobname = "normalise"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm

        Raises
        ------
        ValueError
            If ``write_wrap`` does not contain exactly 3 elements.
        """
        if opt in ("template", "source"):
            return scans_for_fname(ensure_list(val))
        if opt == "apply_to_files":
            return scans_for_fnames(ensure_list(val))
        if opt == "parameter_file":
            return np.array([simplify_list(val)], dtype=object)
        # The trait declares no min/max length, so the size is enforced here.
        if opt == "write_wrap" and len(val) != 3:
            raise ValueError("%s must have 3 elements" % opt)
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Build the SPM job description, keyed by the selected job type.

        ``apply_to_files`` is skipped during generic parsing and handled here:
        the resample list is ``apply_to_files`` followed by ``source`` when
        both are defined.  Without ``apply_to_files``, the writing job types
        resample ``source`` itself.
        """
        einputs = super()._parse_inputs(skip=("jobtype", "apply_to_files"))
        jobtype = self.inputs.jobtype

        if isdefined(self.inputs.apply_to_files):
            # Copy first so extending the list does not mutate the input trait.
            inputfiles = deepcopy(self.inputs.apply_to_files)
            if isdefined(self.inputs.source):
                inputfiles.extend(self.inputs.source)
            einputs[0]["subj"]["resample"] = scans_for_fnames(inputfiles)
        elif jobtype in ("estwrite", "write") and isdefined(self.inputs.source):
            einputs[0]["subj"]["resample"] = scans_for_fname(self.inputs.source)

        return [{str(jobtype): einputs[0]}]

    def _list_outputs(self):
        """Predict output paths from the inputs and the job type.

        Job types starting with "est" report one ``*_sn.mat`` path per source
        file.  Job types containing "write" report the prefixed source and
        ``apply_to_files`` paths; the prefix gains a leading "m" when
        ``write_preserve`` is set and True.

        For the estimate-only job type ("est") no normalized image outputs
        are reported.  The previous code had a branch comparing the job type
        with "estimate", a value the ``jobtype`` enum does not allow, so it
        could never run; it has been removed without changing behaviour.
        """
        outputs = self._outputs().get()
        jobtype = self.inputs.jobtype

        if jobtype.startswith("est"):
            outputs["normalization_parameters"] = simplify_list(
                [
                    fname_presuffix(imgf, suffix="_sn.mat", use_ext=False)
                    for imgf in ensure_list(self.inputs.source)
                ]
            )

        if "write" in jobtype:
            if isdefined(self.inputs.write_preserve) and self.inputs.write_preserve:
                prefix_norm = f"m{self.inputs.out_prefix}"
            else:
                prefix_norm = self.inputs.out_prefix

            outputs["normalized_files"] = []
            if isdefined(self.inputs.apply_to_files):
                for entry in ensure_list(self.inputs.apply_to_files):
                    # Entries may be single files or lists; flatten both.
                    members = entry if isinstance(entry, list) else [entry]
                    outputs["normalized_files"].extend(
                        [fname_presuffix(in_f, prefix=prefix_norm) for in_f in members]
                    )
            if isdefined(self.inputs.source):
                outputs["normalized_source"] = [
                    fname_presuffix(imgf, prefix=prefix_norm)
                    for imgf in ensure_list(self.inputs.source)
                ]

        return outputs


class Normalize12InputSpec(SPMCommandInputSpec):
    """Input traits for the ``Normalize12`` interface.

    Each trait carries a ``field`` metadata string (``subj.*``, ``eoptions.*``
    or ``woptions.*``); presumably this is the path of the matching entry in
    the SPM job structure, but the mapping itself is done by the base class
    and is not visible here.
    """

    # Declared mutually exclusive (xor) with ``deformation_file``.
    image_to_align = ImageFileSPM(
        exists=True,
        field="subj.vol",
        desc=("file to estimate normalization parameters with"),
        xor=["deformation_file"],
        mandatory=True,
        copyfile=True,
    )
    # Each entry may be a single image or a list of images.
    apply_to_files = InputMultiPath(
        traits.Either(
            ImageFileSPM(exists=True), traits.List(ImageFileSPM(exists=True))
        ),
        field="subj.resample",
        desc="files to apply transformation to",
        copyfile=True,
    )
    # Declared mutually exclusive (xor) with ``image_to_align`` and ``tpm``.
    deformation_file = ImageFileSPM(
        field="subj.def",
        mandatory=True,
        xor=["image_to_align", "tpm"],
        copyfile=False,
        desc=(
            "file y_*.nii containing 3 deformation "
            "fields for the deformation in x, y and z "
            "dimension"
        ),
    )
    # Defaults to "estwrite" (first value, usedefault=True).
    jobtype = traits.Enum(
        "estwrite", "est", "write", usedefault=True, desc="Estimate, Write or do Both"
    )
    bias_regularization = traits.Enum(
        0,
        0.00001,
        0.0001,
        0.001,
        0.01,
        0.1,
        1,
        10,
        field="eoptions.biasreg",
        desc="no(0) - extremely heavy (10)",
    )
    # Integer values from 30 to 150 in steps of 10, or the string "Inf".
    bias_fwhm = traits.Enum(
        30,
        40,
        50,
        60,
        70,
        80,
        90,
        100,
        110,
        120,
        130,
        140,
        150,
        "Inf",
        field="eoptions.biasfwhm",
        desc="FWHM of Gaussian smoothness of bias",
    )
    tpm = File(
        exists=True,
        field="eoptions.tpm",
        desc=("template in form of tissue probablitiy maps to normalize to"),
        xor=["deformation_file"],
        copyfile=False,
    )
    affine_regularization_type = traits.Enum(
        "mni", "size", "none", field="eoptions.affreg", desc="mni, size, none"
    )
    # Validated by the trait as a list of exactly 5 floats.
    warping_regularization = traits.List(
        traits.Float(),
        field="eoptions.reg",
        minlen=5,
        maxlen=5,
        desc=("controls balance between parameters and data"),
    )
    smoothness = traits.Float(
        field="eoptions.fwhm",
        desc=("value (in mm) to smooth the data before normalization"),
    )
    sampling_distance = traits.Float(
        field="eoptions.samp",
        desc=("Sampling distance on data for parameter estimation"),
    )
    # Validated as an outer list of exactly 2 inner lists of 3 floats each.
    write_bounding_box = traits.List(
        traits.List(traits.Float(), minlen=3, maxlen=3),
        field="woptions.bb",
        minlen=2,
        maxlen=2,
        desc=(
            "3x2-element list of lists "
            "representing the bounding box "
            "(in mm) to be written"
        ),
    )
    write_voxel_sizes = traits.List(
        traits.Float(),
        field="woptions.vox",
        minlen=3,
        maxlen=3,
        desc=(
            "3-element list representing the "
            "voxel sizes (in mm) of the written "
            "normalised images"
        ),
    )
    write_interp = traits.Range(
        low=0,
        high=7,
        field="woptions.interp",
        desc=("degree of b-spline used for interpolation"),
    )
    # Prepended to filenames by ``Normalize12._list_outputs``; defaults to "w".
    out_prefix = traits.String(
        "w", field="woptions.prefix", usedefault=True, desc="Normalized output prefix"
    )


class Normalize12OutputSpec(TraitedSpec):
    """Output traits reported by the ``Normalize12`` interface."""

    # "y_"-prefixed path derived from ``image_to_align``; only set for job
    # types starting with "est" (see ``Normalize12._list_outputs``).
    deformation_field = OutputMultiPath(
        File(exists=True),
        desc=(
            "NIfTI file containing 3 "
            "deformation fields for the "
            "deformation in x, y and z "
            "dimension"
        ),
    )
    # ``image_to_align`` with ``out_prefix`` prepended, for job types
    # containing "write".
    normalized_image = OutputMultiPath(
        File(exists=True), desc=("Normalized file that needed to be aligned")
    )
    # Prefixed copies of ``apply_to_files``, flattened into a single list.
    normalized_files = OutputMultiPath(File(exists=True), desc="Normalized other files")


class Normalize12(SPMCommand):
    """uses SPM12's new Normalise routine for warping an image to a template.
    Spatial normalisation is now done via the segmentation routine (which was
    known as ``New Segment`` in SPM8). Note that the normalisation in SPM12
    is done towards a file containing multiple tissue probability maps, which
    was not the case in SPM8.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=49

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> norm12 = spm.Normalize12()
    >>> norm12.inputs.image_to_align = 'structural.nii'
    >>> norm12.inputs.apply_to_files = 'functional.nii'
    >>> norm12.run() # doctest: +SKIP

    """

    input_spec = Normalize12InputSpec
    output_spec = Normalize12OutputSpec
    _jobtype = "spatial"
    _jobname = "normalise"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm

        Raises
        ------
        ValueError
            If ``warping_regularization`` does not contain exactly 5 elements.
        """
        if opt in ("tpm", "image_to_align"):
            return scans_for_fname(ensure_list(val))
        if opt == "apply_to_files":
            return scans_for_fnames(ensure_list(val))
        if opt == "deformation_file":
            return np.array([simplify_list(val)], dtype=object)
        # The 5-element input of this interface is ``warping_regularization``;
        # the guard used to test "nonlinear_regularization", which this spec
        # does not define, so it could never fire.
        if opt == "warping_regularization" and len(val) != 5:
            raise ValueError("%s must have 5 elements" % opt)
        return super()._format_arg(opt, spec, val)

    def _parse_inputs(self, skip=()):
        """Build the SPM job description, keyed by the selected job type.

        Parameters
        ----------
        skip : iterable of str, optional
            Extra input names to leave out of the job, in addition to
            ``jobtype`` and ``apply_to_files`` which are always handled here.

        ``apply_to_files`` is skipped during generic parsing and handled
        here: the resample list is ``apply_to_files`` followed by
        ``image_to_align`` when both are defined.  Without
        ``apply_to_files``, the writing job types resample
        ``image_to_align`` itself.
        """
        # Honour caller-supplied names instead of silently dropping them.
        all_skipped = ("jobtype", "apply_to_files") + tuple(skip or ())
        einputs = super()._parse_inputs(skip=all_skipped)
        jobtype = self.inputs.jobtype

        if isdefined(self.inputs.apply_to_files):
            # Copy first so extending the list does not mutate the input trait.
            inputfiles = deepcopy(self.inputs.apply_to_files)
            if isdefined(self.inputs.image_to_align):
                inputfiles.extend([self.inputs.image_to_align])
            einputs[0]["subj"]["resample"] = scans_for_fnames(inputfiles)
        elif jobtype in ("estwrite", "write") and isdefined(
            self.inputs.image_to_align
        ):
            einputs[0]["subj"]["resample"] = scans_for_fname(
                self.inputs.image_to_align
            )

        return [{str(jobtype): einputs[0]}]

    def _list_outputs(self):
        """Predict output paths from the inputs and the job type.

        Job types starting with "est" report the "y_"-prefixed deformation
        field derived from ``image_to_align``.  Job types containing "write"
        report ``apply_to_files`` and ``image_to_align`` with ``out_prefix``
        prepended to each filename.
        """
        outputs = self._outputs().get()
        jobtype = self.inputs.jobtype

        if jobtype.startswith("est"):
            outputs["deformation_field"] = simplify_list(
                [
                    fname_presuffix(imgf, prefix="y_")
                    for imgf in ensure_list(self.inputs.image_to_align)
                ]
            )

        if "write" in jobtype:
            prefix = self.inputs.out_prefix
            outputs["normalized_files"] = []
            if isdefined(self.inputs.apply_to_files):
                for entry in ensure_list(self.inputs.apply_to_files):
                    # Entries may be single files or lists; flatten both.
                    members = entry if isinstance(entry, list) else [entry]
                    outputs["normalized_files"].extend(
                        [fname_presuffix(in_f, prefix=prefix) for in_f in members]
                    )
            if isdefined(self.inputs.image_to_align):
                outputs["normalized_image"] = fname_presuffix(
                    self.inputs.image_to_align, prefix=prefix
                )

        return outputs


class SegmentInputSpec(SPMCommandInputSpec):
    """Input traits for the ``Segment`` interface.

    Each trait carries a ``field`` metadata string (``data``, ``output.*`` or
    ``opts.*``); presumably this is the path of the matching entry in the SPM
    job structure, but the mapping itself is done by the base class and is
    not visible here.
    """

    # Only the first entry is used by ``Segment._list_outputs`` to derive
    # output filenames.
    data = InputMultiPath(
        ImageFileSPM(exists=True),
        field="data",
        desc="one scan per subject",
        copyfile=False,
        mandatory=True,
    )
    # The three *_output_type traits are lists of exactly 3 booleans, read by
    # ``Segment._list_outputs`` in the order [modulated, normalized, native]
    # and converted to ints by ``Segment._format_arg``.
    gm_output_type = traits.List(
        traits.Bool(),
        minlen=3,
        maxlen=3,
        field="output.GM",
        desc="""Options to produce grey matter images: c1*.img, wc1*.img and mwc1*.img.
            None: [False,False,False],
            Native Space: [False,False,True],
            Unmodulated Normalised: [False,True,False],
            Modulated Normalised: [True,False,False],
            Native + Unmodulated Normalised: [False,True,True],
            Native + Modulated Normalised: [True,False,True],
            Native + Modulated + Unmodulated: [True,True,True],
            Modulated + Unmodulated Normalised: [True,True,False]""",
    )
    wm_output_type = traits.List(
        traits.Bool(),
        minlen=3,
        maxlen=3,
        field="output.WM",
        desc="""
            Options to produce white matter images: c2*.img, wc2*.img and mwc2*.img.
            None: [False,False,False],
            Native Space: [False,False,True],
            Unmodulated Normalised: [False,True,False],
            Modulated Normalised: [True,False,False],
            Native + Unmodulated Normalised: [False,True,True],
            Native + Modulated Normalised: [True,False,True],
            Native + Modulated + Unmodulated: [True,True,True],
            Modulated + Unmodulated Normalised: [True,True,False]""",
    )
    csf_output_type = traits.List(
        traits.Bool(),
        minlen=3,
        maxlen=3,
        field="output.CSF",
        desc="""
            Options to produce CSF images: c3*.img, wc3*.img and mwc3*.img.
            None: [False,False,False],
            Native Space: [False,False,True],
            Unmodulated Normalised: [False,True,False],
            Modulated Normalised: [True,False,False],
            Native + Unmodulated Normalised: [False,True,True],
            Native + Modulated Normalised: [True,False,True],
            Native + Modulated + Unmodulated: [True,True,True],
            Modulated + Unmodulated Normalised: [True,True,False]""",
    )
    # When set and True, ``Segment._list_outputs`` reports an "m"-prefixed
    # bias corrected image.
    save_bias_corrected = traits.Bool(
        field="output.biascor", desc=("True/False produce a bias corrected image")
    )
    # ``Segment._format_arg`` maps these strings to 0, 1 and 2 respectively.
    clean_masks = traits.Enum(
        "no",
        "light",
        "thorough",
        field="output.cleanup",
        desc=("clean using estimated brain mask ('no','light','thorough')"),
    )
    tissue_prob_maps = traits.List(
        File(exists=True),
        field="opts.tpm",
        desc=("list of gray, white & csf prob. (opt,)"),
    )
    gaussians_per_class = traits.List(
        traits.Int(),
        field="opts.ngaus",
        desc=("num Gaussians capture intensity distribution"),
    )
    affine_regularization = traits.Enum(
        "mni",
        "eastern",
        "subj",
        "none",
        "",
        field="opts.regtype",
        desc=(
            'Possible options: "mni", '
            '"eastern", "subj", "none" '
            '(no reguralisation), "" '
            "(no affine registration)"
        ),
    )
    warping_regularization = traits.Float(
        field="opts.warpreg", desc=("Controls balance between parameters and data")
    )
    warp_frequency_cutoff = traits.Float(
        field="opts.warpco", desc="Cutoff of DCT bases"
    )
    bias_regularization = traits.Enum(
        0,
        0.00001,
        0.0001,
        0.001,
        0.01,
        0.1,
        1,
        10,
        field="opts.biasreg",
        desc="no(0) - extremely heavy (10)",
    )
    # Integer values from 30 to 130 in steps of 10, or the string "Inf".
    bias_fwhm = traits.Enum(
        30,
        40,
        50,
        60,
        70,
        80,
        90,
        100,
        110,
        120,
        130,
        "Inf",
        field="opts.biasfwhm",
        desc="FWHM of Gaussian smoothness of bias",
    )
    sampling_distance = traits.Float(
        field="opts.samp", desc=("Sampling distance on data for parameter estimation")
    )
    mask_image = File(
        exists=True,
        field="opts.msk",
        desc="Binary image to restrict parameter estimation ",
    )


class SegmentOutputSpec(TraitedSpec):
    """Output traits reported by the ``Segment`` interface.

    The per-tissue images are only filled by ``Segment._list_outputs`` when
    the matching flag of the corresponding ``*_output_type`` input is true.
    """

    # Grey matter (tissue class 1).
    native_gm_image = File(desc="native space grey probability map")
    normalized_gm_image = File(desc="normalized grey probability map")
    modulated_gm_image = File(desc=("modulated, normalized grey probability map"))
    # White matter (tissue class 2).
    native_wm_image = File(desc="native space white probability map")
    normalized_wm_image = File(desc="normalized white probability map")
    modulated_wm_image = File(desc=("modulated, normalized white probability map"))
    # CSF (tissue class 3).
    native_csf_image = File(desc="native space csf probability map")
    normalized_csf_image = File(desc="normalized csf probability map")
    modulated_csf_image = File(desc=("modulated, normalized csf probability map"))
    # Deprecated alias; the metadata points to ``bias_corrected_image`` as the
    # replacement name.
    modulated_input_image = File(
        deprecated="0.10",
        new_name="bias_corrected_image",
        desc="bias-corrected version of input image",
    )
    bias_corrected_image = File(desc="bias-corrected version of input image")
    # Always filled: "<data[0]>_seg_sn.mat" and "<data[0]>_seg_inv_sn.mat".
    transformation_mat = File(exists=True, desc="Normalization transformation")
    inverse_transformation_mat = File(exists=True, desc="Inverse normalization info")


class Segment(SPMCommand):
    """use spm_segment to separate structural images into different
    tissue classes.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=209

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> seg = spm.Segment()
    >>> seg.inputs.data = 'structural.nii'
    >>> seg.run() # doctest: +SKIP

    """

    input_spec = SegmentInputSpec
    output_spec = SegmentOutputSpec

    def __init__(self, **inputs):
        """Select the SPM job type/name from the detected SPM version.

        A version string containing "12." selects the "tools"/"oldseg" job;
        anything else (including an unknown version) selects
        "spatial"/"preproc".
        """
        detected = SPMCommand().version
        if detected and "12." in detected:
            self._jobtype, self._jobname = "tools", "oldseg"
        else:
            self._jobtype, self._jobname = "spatial", "preproc"

        SPMCommand.__init__(self, **inputs)

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ("data", "tissue_prob_maps"):
            # Lists go through the plural helper, single values the singular.
            to_scans = scans_for_fnames if isinstance(val, list) else scans_for_fname
            return to_scans(val)
        if "output_type" in opt:
            # Boolean flags are handed over as integers.
            return list(map(int, val))
        if opt == "mask_image":
            return scans_for_fname(val)
        if opt == "clean_masks":
            return {"no": 0, "light": 1, "thorough": 2}[val]
        return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict output paths, all derived from the first ``data`` file."""
        outputs = self._outputs().get()
        first_scan = self.inputs.data[0]
        # Position in each ``*_output_type`` list -> (output kind, prefix).
        variants = (("modulated", "mw"), ("normalized", "w"), ("native", ""))

        for class_no, tissue in enumerate(("gm", "wm", "csf"), start=1):
            flags = getattr(self.inputs, f"{tissue}_output_type")
            if not isdefined(flags):
                continue
            for position, (image, prefix) in enumerate(variants):
                if not flags[position]:
                    continue
                outputs[f"{image}_{tissue}_image"] = fname_presuffix(
                    first_scan, prefix=f"{prefix}c{class_no}"
                )

        wants_bias_corrected = (
            isdefined(self.inputs.save_bias_corrected)
            and self.inputs.save_bias_corrected
        )
        if wants_bias_corrected:
            outputs["bias_corrected_image"] = fname_presuffix(first_scan, prefix="m")

        outputs["transformation_mat"] = fname_presuffix(
            first_scan, suffix="_seg_sn.mat", use_ext=False
        )
        outputs["inverse_transformation_mat"] = fname_presuffix(
            first_scan, suffix="_seg_inv_sn.mat", use_ext=False
        )
        return outputs


class NewSegmentInputSpec(SPMCommandInputSpec):
    """Input traits for the ``NewSegment`` interface."""

    # ``channel_files`` and ``channel_info`` share the field name "channel";
    # ``NewSegment._format_arg`` combines them into a single structure.
    channel_files = InputMultiPath(
        ImageFileSPM(exists=True),
        mandatory=True,
        desc="A list of files to be segmented",
        field="channel",
        copyfile=False,
    )
    channel_info = Tuple(
        traits.Float(),
        traits.Float(),
        Tuple(traits.Bool, traits.Bool),
        desc="""A tuple with the following fields:
            - bias reguralisation (0-10)
            - FWHM of Gaussian smoothness of bias
            - which maps to save (Field, Corrected) - a tuple of two boolean values""",
        field="channel",
    )
    # One 4-tuple per tissue class; when left undefined,
    # ``NewSegment._list_outputs`` assumes 5 classes.
    tissues = traits.List(
        Tuple(
            Tuple(ImageFileSPM(exists=True), traits.Int()),
            traits.Int(),
            Tuple(traits.Bool, traits.Bool),
            Tuple(traits.Bool, traits.Bool),
        ),
        desc="""A list of tuples (one per tissue) with the following fields:
            - tissue probability map (4D), 1-based index to frame
            - number of gaussians
            - which maps to save [Native, DARTEL] - a tuple of two boolean values
            - which maps to save [Unmodulated, Modulated] - a tuple of two boolean values""",
        field="tissue",
    )
    affine_regularization = traits.Enum(
        "mni",
        "eastern",
        "subj",
        "none",
        field="warp.affreg",
        desc="mni, eastern, subj, none ",
    )
    # Either a list of exactly 5 floats or a single float.
    warping_regularization = traits.Either(
        traits.List(traits.Float(), minlen=5, maxlen=5),
        traits.Float(),
        field="warp.reg",
        desc=(
            "Warping regularization "
            "parameter(s). Accepts float "
            "or list of floats (the "
            "latter is required by "
            "SPM12)"
        ),
    )
    sampling_distance = traits.Float(
        field="warp.samp", desc=("Sampling distance on data for parameter estimation")
    )
    # Exactly two booleans, converted to ints by ``NewSegment._format_arg``.
    write_deformation_fields = traits.List(
        traits.Bool(),
        minlen=2,
        maxlen=2,
        field="warp.write",
        desc=("Which deformation fields to write:[Inverse, Forward]"),
    )


class NewSegmentOutputSpec(TraitedSpec):
    """Output traits reported by the ``NewSegment`` interface."""

    # The four *_class_images outputs are lists of lists: the outer index is
    # the tissue class, the inner list holds one path per channel file
    # (see ``NewSegment._list_outputs``).
    native_class_images = traits.List(
        traits.List(File(exists=True)), desc="native space probability maps"
    )
    dartel_input_images = traits.List(
        traits.List(File(exists=True)), desc="dartel imported class images"
    )
    normalized_class_images = traits.List(
        traits.List(File(exists=True)), desc="normalized class images"
    )
    modulated_class_images = traits.List(
        traits.List(File(exists=True)), desc=("modulated+normalized class images")
    )
    # "<base>_seg8.mat", one per channel file.
    transformation_mat = OutputMultiPath(
        File(exists=True), desc="Normalization transformation"
    )
    # "m<base>.nii" and "BiasField_<base>.nii", filled according to the
    # save flags in ``channel_info``.
    bias_corrected_images = OutputMultiPath(
        File(exists=True), desc="bias corrected images"
    )
    bias_field_images = OutputMultiPath(File(exists=True), desc="bias field images")
    # "y_<base>.nii" and "iy_<base>.nii", filled according to
    # ``write_deformation_fields``.
    forward_deformation_field = OutputMultiPath(File(exists=True))
    inverse_deformation_field = OutputMultiPath(File(exists=True))


class NewSegment(SPMCommand):
    """Use spm_preproc8 (New Segment) to separate structural images into
    different tissue classes. Supports multiple modalities.

    NOTE: This interface currently supports single channel input only

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=43

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> seg = spm.NewSegment()
    >>> seg.inputs.channel_files = 'structural.nii'
    >>> seg.inputs.channel_info = (0.0001, 60, (True, True))
    >>> seg.run() # doctest: +SKIP

    For VBM pre-processing [http://www.fil.ion.ucl.ac.uk/~john/misc/VBMclass10.pdf],
    TPM.nii should be replaced by /path/to/spm8/toolbox/Seg/TPM.nii

    >>> seg = NewSegment()
    >>> seg.inputs.channel_files = 'structural.nii'
    >>> tissue1 = (('TPM.nii', 1), 2, (True,True), (False, False))
    >>> tissue2 = (('TPM.nii', 2), 2, (True,True), (False, False))
    >>> tissue3 = (('TPM.nii', 3), 2, (True,False), (False, False))
    >>> tissue4 = (('TPM.nii', 4), 2, (False,False), (False, False))
    >>> tissue5 = (('TPM.nii', 5), 2, (False,False), (False, False))
    >>> seg.inputs.tissues = [tissue1, tissue2, tissue3, tissue4, tissue5]
    >>> seg.run() # doctest: +SKIP

    """

    input_spec = NewSegmentInputSpec
    output_spec = NewSegmentOutputSpec

    def __init__(self, **inputs):
        """Select the SPM job type/name from the detected SPM version.

        A version string containing "12." selects the "spatial"/"preproc"
        job; anything else (including an unknown version) selects
        "tools"/"preproc8".
        """
        detected = SPMCommand().version
        if detected and "12." in detected:
            self._jobtype, self._jobname = "spatial", "preproc"
        else:
            self._jobtype, self._jobname = "tools", "preproc8"

        SPMCommand.__init__(self, **inputs)

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ("channel_files", "channel_info"):
            # Both inputs feed one structure, rebuilt from scratch on each
            # call (the original author notes a traits error otherwise).
            channel = {"vols": scans_for_fnames(self.inputs.channel_files)}
            info = self.inputs.channel_info
            if isdefined(info):
                biasreg, biasfwhm, save_flags = info
                channel["biasreg"] = biasreg
                channel["biasfwhm"] = biasfwhm
                channel["write"] = [int(save_flags[0]), int(save_flags[1])]
            return [channel]

        if opt == "tissues":
            # One dict per tissue; the TPM entry is "<file>,<frame index>".
            return [
                {
                    "tpm": np.array(
                        [",".join([tpm_spec[0], str(tpm_spec[1])])], dtype=object
                    ),
                    "ngaus": n_gaussians,
                    "native": [int(native_flags[0]), int(native_flags[1])],
                    "warped": [int(warped_flags[0]), int(warped_flags[1])],
                }
                for tpm_spec, n_gaussians, native_flags, warped_flags in val
            ]

        if opt == "write_deformation_fields":
            val = [int(val[0]), int(val[1])]
        return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict output paths for every channel file.

        Per-class outputs are lists indexed by tissue class (5 classes when
        ``tissues`` is undefined), each holding one path per channel file.
        """
        outputs = self._outputs().get()
        inputs = self.inputs

        tissues = inputs.tissues
        tissues_given = isdefined(tissues)
        n_classes = len(tissues) if tissues_given else 5

        per_class_keys = (
            "native_class_images",
            "dartel_input_images",
            "normalized_class_images",
            "modulated_class_images",
        )
        flat_keys = (
            "transformation_mat",
            "bias_corrected_images",
            "bias_field_images",
            "inverse_deformation_field",
            "forward_deformation_field",
        )
        for key in per_class_keys:
            outputs[key] = [[] for _ in range(n_classes)]
        for key in flat_keys:
            outputs[key] = []

        for filename in inputs.channel_files:
            pth, base, _ = split_filename(filename)

            if tissues_given:
                for class_no, tissue in enumerate(tissues, start=1):
                    native, dartel = tissue[2]
                    unmodulated, modulated = tissue[3]
                    requested = (
                        (native, "native_class_images", "c"),
                        (dartel, "dartel_input_images", "rc"),
                        (unmodulated, "normalized_class_images", "wc"),
                        (modulated, "modulated_class_images", "mwc"),
                    )
                    for wanted, key, prefix in requested:
                        if wanted:
                            outputs[key][class_no - 1].append(
                                os.path.join(pth, f"{prefix}{class_no}{base}.nii")
                            )
            else:
                # Without tissue settings only native maps are reported.
                for class_no, bucket in enumerate(
                    outputs["native_class_images"], start=1
                ):
                    bucket.append(os.path.join(pth, f"c{class_no}{base}.nii"))

            outputs["transformation_mat"].append(os.path.join(pth, f"{base}_seg8.mat"))

            fields = inputs.write_deformation_fields
            if isdefined(fields):
                if fields[0]:
                    outputs["inverse_deformation_field"].append(
                        os.path.join(pth, f"iy_{base}.nii")
                    )
                if fields[1]:
                    outputs["forward_deformation_field"].append(
                        os.path.join(pth, f"y_{base}.nii")
                    )

            info = inputs.channel_info
            if isdefined(info):
                save_field, save_corrected = info[2]
                if save_field:
                    outputs["bias_field_images"].append(
                        os.path.join(pth, f"BiasField_{base}.nii")
                    )
                if save_corrected:
                    outputs["bias_corrected_images"].append(
                        os.path.join(pth, f"m{base}.nii")
                    )
        return outputs


class MultiChannelNewSegmentInputSpec(SPMCommandInputSpec):
    # Input specification for MultiChannelNewSegment.
    #
    # ``channels`` is declared mandatory on the top-level trait because
    # MultiChannelNewSegment._list_outputs indexes ``channels[0][0]``
    # unconditionally; previously only the nested InputMultiPath carried
    # ``mandatory=True``.
    # NOTE(review): metadata placed on nested traits (the ``mandatory``,
    # ``desc``, ``field`` and ``copyfile`` entries below) appears not to be
    # consulted by the interface-level checks - confirm before relying on it.
    channels = traits.List(
        Tuple(
            InputMultiPath(
                ImageFileSPM(exists=True),
                mandatory=True,
                desc="A list of files to be segmented",
                field="channel",
                copyfile=False,
            ),
            Tuple(
                traits.Float(),
                traits.Float(),
                Tuple(traits.Bool, traits.Bool),
                desc="""A tuple with the following fields:
                    - bias reguralisation (0-10)
                    - FWHM of Gaussian smoothness of bias
                    - which maps to save (Field, Corrected) - a tuple of two boolean values""",
                field="channel",
            ),
        ),
        mandatory=True,
        desc="""A list of tuples (one per each channel) with the following fields:
            - a list of channel files (only 1rst channel files will be segmented)
            - a tuple with the following channel-specific info fields:
              - bias reguralisation (0-10)
              - FWHM of Gaussian smoothness of bias
              - which maps to save (Field, Corrected) - a tuple of two boolean values""",
        field="channel",
    )
    # One 4-tuple per tissue class; consumed by
    # MultiChannelNewSegment._format_arg as
    # ((tpm_file, frame), n_gaussians, (native, dartel), (unmod, mod)).
    tissues = traits.List(
        Tuple(
            Tuple(ImageFileSPM(exists=True), traits.Int()),
            traits.Int(),
            Tuple(traits.Bool, traits.Bool),
            Tuple(traits.Bool, traits.Bool),
        ),
        desc="""A list of tuples (one per tissue) with the following fields:
            - tissue probability map (4D), 1-based index to frame
            - number of gaussians
            - which maps to save [Native, DARTEL] - a tuple of two boolean values
            - which maps to save [Unmodulated, Modulated] - a tuple of two boolean values""",
        field="tissue",
    )
    affine_regularization = traits.Enum(
        "mni",
        "eastern",
        "subj",
        "none",
        field="warp.affreg",
        desc="mni, eastern, subj, none ",
    )
    # Either a single float or a list of exactly five floats.
    warping_regularization = traits.Either(
        traits.List(traits.Float(), minlen=5, maxlen=5),
        traits.Float(),
        field="warp.reg",
        desc=(
            "Warping regularization "
            "parameter(s). Accepts float "
            "or list of floats (the "
            "latter is required by "
            "SPM12)"
        ),
    )
    sampling_distance = traits.Float(
        field="warp.samp", desc=("Sampling distance on data for parameter estimation")
    )
    # Exactly two booleans: [inverse, forward]; converted to ints in
    # MultiChannelNewSegment._format_arg.
    write_deformation_fields = traits.List(
        traits.Bool(),
        minlen=2,
        maxlen=2,
        field="warp.write",
        desc=("Which deformation fields to write:[Inverse, Forward]"),
    )


class MultiChannelNewSegmentOutputSpec(TraitedSpec):
    # Per-tissue outputs: MultiChannelNewSegment._list_outputs fills the
    # outer list with one entry per tissue class and each inner list with
    # one path per first-channel image.
    native_class_images = traits.List(
        traits.List(File(exists=True)),
        desc="native space probability maps",
    )
    dartel_input_images = traits.List(
        traits.List(File(exists=True)),
        desc="dartel imported class images",
    )
    normalized_class_images = traits.List(
        traits.List(File(exists=True)),
        desc="normalized class images",
    )
    modulated_class_images = traits.List(
        traits.List(File(exists=True)),
        desc="modulated+normalized class images",
    )

    # Flat outputs.
    transformation_mat = OutputMultiPath(
        File(exists=True),
        desc="Normalization transformation",
    )
    bias_corrected_images = OutputMultiPath(
        File(exists=True),
        desc="bias corrected images",
    )
    bias_field_images = OutputMultiPath(
        File(exists=True),
        desc="bias field images",
    )
    forward_deformation_field = OutputMultiPath(File(exists=True))
    inverse_deformation_field = OutputMultiPath(File(exists=True))


class MultiChannelNewSegment(SPMCommand):
    """Segment structural images into tissue classes with spm_preproc8.

    This is the "New Segment" routine with support for multiple modalities
    and multichannel inputs.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=45

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> seg = spm.MultiChannelNewSegment()
    >>> seg.inputs.channels = [('structural.nii',(0.0001, 60, (True, True)))]
    >>> seg.run() # doctest: +SKIP

    For VBM pre-processing [http://www.fil.ion.ucl.ac.uk/~john/misc/VBMclass10.pdf],
    TPM.nii should be replaced by /path/to/spm8/toolbox/Seg/TPM.nii

    >>> seg = MultiChannelNewSegment()
    >>> channel1= ('T1.nii',(0.0001, 60, (True, True)))
    >>> channel2= ('T2.nii',(0.0001, 60, (True, True)))
    >>> seg.inputs.channels = [channel1, channel2]
    >>> tissue1 = (('TPM.nii', 1), 2, (True,True), (False, False))
    >>> tissue2 = (('TPM.nii', 2), 2, (True,True), (False, False))
    >>> tissue3 = (('TPM.nii', 3), 2, (True,False), (False, False))
    >>> tissue4 = (('TPM.nii', 4), 2, (False,False), (False, False))
    >>> tissue5 = (('TPM.nii', 5), 2, (False,False), (False, False))
    >>> seg.inputs.tissues = [tissue1, tissue2, tissue3, tissue4, tissue5]
    >>> seg.run() # doctest: +SKIP

    """

    input_spec = MultiChannelNewSegmentInputSpec
    output_spec = MultiChannelNewSegmentOutputSpec

    def __init__(self, **inputs):
        """Select the SPM job type/name from the detected SPM version."""
        spm_version = SPMCommand().version
        # A version string containing "12." selects the spatial/preproc job;
        # anything else (including no detected version) uses tools/preproc8.
        if spm_version and "12." in spm_version:
            self._jobtype, self._jobname = "spatial", "preproc"
        else:
            self._jobtype, self._jobname = "tools", "preproc8"

        SPMCommand.__init__(self, **inputs)

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt == "channels":
            # The structure is rebuilt from scratch rather than passed on
            # (the original author cites a traits error as the reason).
            formatted_channels = []
            for channel in val:
                entry = {"vols": scans_for_fnames(channel[0])}
                settings = channel[1]
                if isdefined(settings):
                    save_flags = settings[2]
                    entry["biasreg"] = settings[0]
                    entry["biasfwhm"] = settings[1]
                    entry["write"] = [int(save_flags[0]), int(save_flags[1])]
                formatted_channels.append(entry)
            return formatted_channels

        if opt == "tissues":
            # "tpm" is a one-element object array holding "<file>,<frame>".
            return [
                {
                    "tpm": np.array(
                        [",".join([tissue[0][0], str(tissue[0][1])])], dtype=object
                    ),
                    "ngaus": tissue[1],
                    "native": [int(tissue[2][0]), int(tissue[2][1])],
                    "warped": [int(tissue[3][0]), int(tissue[3][1])],
                }
                for tissue in val
            ]

        if opt == "write_deformation_fields":
            # Booleans are handed to the parent formatter as ints.
            val = [int(val[0]), int(val[1])]
        return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict the output file paths from the current inputs."""
        outputs = self._outputs().get()

        tissues = self.inputs.tissues
        tissues_given = isdefined(tissues)
        # Five classes are assumed when no tissues are specified.
        n_classes = len(tissues) if tissues_given else 5

        for key in (
            "native_class_images",
            "dartel_input_images",
            "normalized_class_images",
            "modulated_class_images",
        ):
            outputs[key] = [[] for _ in range(n_classes)]
        for key in (
            "transformation_mat",
            "bias_corrected_images",
            "bias_field_images",
            "inverse_deformation_field",
            "forward_deformation_field",
        ):
            outputs[key] = []

        deformations = self.inputs.write_deformation_fields
        save_inverse = save_forward = False
        if isdefined(deformations):
            save_inverse = deformations[0]
            save_forward = deformations[1]

        channels = self.inputs.channels

        # main outputs are generated for the first channel images only
        for filename in channels[0][0]:
            pth, base, _ = split_filename(filename)
            if tissues_given:
                for idx, tissue in enumerate(tissues):
                    # (enabled flag, output key, file name pattern)
                    requested = (
                        (tissue[2][0], "native_class_images", "c%d%s.nii"),
                        (tissue[2][1], "dartel_input_images", "rc%d%s.nii"),
                        (tissue[3][0], "normalized_class_images", "wc%d%s.nii"),
                        (tissue[3][1], "modulated_class_images", "mwc%d%s.nii"),
                    )
                    for enabled, key, pattern in requested:
                        if enabled:
                            outputs[key][idx].append(
                                os.path.join(pth, pattern % (idx + 1, base))
                            )
            else:
                for idx in range(n_classes):
                    outputs["native_class_images"][idx].append(
                        os.path.join(pth, "c%d%s.nii" % (idx + 1, base))
                    )

            outputs["transformation_mat"].append(
                os.path.join(pth, "%s_seg8.mat" % base)
            )

            if save_inverse:
                outputs["inverse_deformation_field"].append(
                    os.path.join(pth, "iy_%s.nii" % base)
                )
            if save_forward:
                outputs["forward_deformation_field"].append(
                    os.path.join(pth, "y_%s.nii" % base)
                )

        # bias field related images are generated for images in all channels
        for channel in channels:
            settings = channel[1]
            for filename in channel[0]:
                pth, base, _ = split_filename(filename)
                if not isdefined(settings):
                    continue
                if settings[2][0]:
                    outputs["bias_field_images"].append(
                        os.path.join(pth, "BiasField_%s.nii" % (base))
                    )
                if settings[2][1]:
                    outputs["bias_corrected_images"].append(
                        os.path.join(pth, "m%s.nii" % (base))
                    )
        return outputs


class SmoothInputSpec(SPMCommandInputSpec):
    # Images to smooth (required).
    in_files = InputMultiPath(
        ImageFileSPM(exists=True),
        mandatory=True,
        copyfile=False,
        field="data",
        desc="list of files to smooth",
    )
    # Either three floats or a single float; Smooth._format_arg expands a
    # scalar to three identical values.
    fwhm = traits.Either(
        traits.List(traits.Float(), minlen=3, maxlen=3),
        traits.Float(),
        field="fwhm",
        desc="3-list of fwhm for each dimension",
    )
    data_type = traits.Int(
        field="dtype",
        desc="Data type of the output images",
    )
    implicit_masking = traits.Bool(
        field="im",
        desc="A mask implied by a particular voxel value",
    )
    # Prefix that Smooth._list_outputs prepends to each input file name.
    out_prefix = traits.String(
        "s",
        usedefault=True,
        field="prefix",
        desc="smoothed output prefix",
    )


class SmoothOutputSpec(TraitedSpec):
    # Filled by Smooth._list_outputs with one path per input file.
    smoothed_files = OutputMultiPath(
        File(exists=True),
        desc="smoothed files",
    )


class Smooth(SPMCommand):
    """Use spm_smooth for 3D Gaussian smoothing of image volumes.

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=55

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> smooth = spm.Smooth()
    >>> smooth.inputs.in_files = 'functional.nii'
    >>> smooth.inputs.fwhm = [4, 4, 4]
    >>> smooth.run() # doctest: +SKIP
    """

    input_spec = SmoothInputSpec
    output_spec = SmoothOutputSpec
    _jobtype = "spatial"
    _jobname = "smooth"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ["in_files"]:
            return scans_for_fnames(ensure_list(val))
        if opt == "fwhm":
            # A scalar (or a one-element list) is expanded to three
            # identical values; any other list is passed through unchanged.
            if not isinstance(val, list):
                return [val, val, val]
            if len(val) == 1:
                return [val[0], val[0], val[0]]
            return val
        return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Return the expected smoothed file for every input file."""
        outputs = self._outputs().get()
        # Each output is the input file name with ``out_prefix`` prepended.
        outputs["smoothed_files"] = [
            fname_presuffix(imgf, prefix=self.inputs.out_prefix)
            for imgf in ensure_list(self.inputs.in_files)
        ]
        return outputs
class DARTELInputSpec(SPMCommandInputSpec):
    # Outer list of inner file lists; DARTEL._list_outputs derives the flow
    # field names from the first inner list only.
    image_files = traits.List(
        traits.List(ImageFileSPM(exists=True)),
        desc="A list of files to be segmented",
        field="warp.images",
        copyfile=False,
        mandatory=True,
    )
    template_prefix = traits.Str(
        "Template",
        usedefault=True,
        field="warp.settings.template",
        desc="Prefix for template",
    )
    # Mapped to 0/1/2 in DARTEL._format_arg.
    regularization_form = traits.Enum(
        "Linear",
        "Membrane",
        "Bending",
        field="warp.settings.rform",
        desc=("Form of regularization energy term"),
    )
    iteration_parameters = traits.List(
        Tuple(
            traits.Range(1, 10),
            Tuple(traits.Float, traits.Float, traits.Float),
            traits.Enum(1, 2, 4, 8, 16, 32, 64, 128, 256, 512),
            traits.Enum(0, 0.5, 1, 2, 4, 8, 16, 32),
        ),
        minlen=3,
        maxlen=12,
        field="warp.settings.param",
        desc="""\
List of tuples for each iteration

  * Inner iterations
  * Regularization parameters
  * Time points for deformation model
  * smoothing parameter

""",
    )
    optimization_parameters = Tuple(
        traits.Float,
        traits.Range(1, 8),
        traits.Range(1, 8),
        field="warp.settings.optim",
        desc="""\
Optimization settings a tuple:

  * LM regularization
  * cycles of multigrid solver
  * relaxation iterations

""",
    )


class DARTELOutputSpec(TraitedSpec):
    final_template_file = File(exists=True, desc="final DARTEL template")
    template_files = traits.List(
        File(exists=True), desc=("Templates from different stages of iteration")
    )
    dartel_flow_fields = traits.List(File(exists=True), desc="DARTEL flow fields")


class DARTEL(SPMCommand):
    """Use spm DARTEL to create a template and flow fields

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=185

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> dartel = spm.DARTEL()
    >>> dartel.inputs.image_files = [['rc1s1.nii','rc1s2.nii'],['rc2s1.nii', 'rc2s2.nii']]
    >>> dartel.run() # doctest: +SKIP

    """

    input_spec = DARTELInputSpec
    output_spec = DARTELOutputSpec
    _jobtype = "tools"
    _jobname = "dartel"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""

        if opt in ["image_files"]:
            return scans_for_fnames(val, keep4d=True, separate_sessions=True)
        elif opt == "regularization_form":
            mapper = {"Linear": 0, "Membrane": 1, "Bending": 2}
            return mapper[val]
        elif opt == "iteration_parameters":
            # One dict per iteration tuple, in the order declared by the
            # input spec: (its, rparam, K, slam).
            return [
                {
                    "its": param[0],
                    "rparam": list(param[1]),
                    "K": param[2],
                    "slam": param[3],
                }
                for param in val
            ]
        elif opt == "optimization_parameters":
            return [{"lmreg": val[0], "cyc": val[1], "its": val[2]}]
        else:
            return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict template and flow-field paths in the working directory."""
        outputs = self._outputs().get()
        template_prefix = self.inputs.template_prefix

        # Six templates, numbered 1..6, are always listed; the final
        # template is number 6.
        outputs["template_files"] = [
            os.path.realpath("%s_%d.nii" % (template_prefix, i + 1))
            for i in range(6)
        ]
        outputs["final_template_file"] = os.path.realpath(
            "%s_6.nii" % template_prefix
        )

        # Flow fields are named after the files of the first image list only.
        outputs["dartel_flow_fields"] = []
        for filename in self.inputs.image_files[0]:
            _, base, ext = split_filename(filename)
            outputs["dartel_flow_fields"].append(
                os.path.realpath(f"u_{base}_{template_prefix}{ext}")
            )
        return outputs


class DARTELNorm2MNIInputSpec(SPMCommandInputSpec):
    template_file = ImageFileSPM(
        exists=True,
        copyfile=False,
        mandatory=True,
        desc="DARTEL template",
        field="mni_norm.template",
    )
    flowfield_files = InputMultiPath(
        ImageFileSPM(exists=True),
        mandatory=True,
        desc="DARTEL flow fields u_rc1*",
        field="mni_norm.data.subjs.flowfields",
    )
    apply_to_files = InputMultiPath(
        ImageFileSPM(exists=True),
        desc="Files to apply the transform to",
        field="mni_norm.data.subjs.images",
        mandatory=True,
        copyfile=False,
    )
    voxel_size = Tuple(
        traits.Float,
        traits.Float,
        traits.Float,
        desc="Voxel sizes for output file",
        field="mni_norm.vox",
    )
    # NOTE(review): the description below is identical to ``voxel_size``'s
    # and looks copy-pasted; left unchanged because it is user-facing text.
    bounding_box = Tuple(
        traits.Float,
        traits.Float,
        traits.Float,
        traits.Float,
        traits.Float,
        traits.Float,
        desc="Voxel sizes for output file",
        field="mni_norm.bb",
    )
    modulate = traits.Bool(
        field="mni_norm.preserve",
        desc=("Modulate out images - no modulation preserves concentrations"),
    )
    fwhm = traits.Either(
        traits.List(traits.Float(), minlen=3, maxlen=3),
        traits.Float(),
        field="mni_norm.fwhm",
        desc="3-list of fwhm for each dimension",
    )


class DARTELNorm2MNIOutputSpec(TraitedSpec):
    normalized_files = OutputMultiPath(
        File(exists=True), desc="Normalized files in MNI space"
    )
    normalization_parameter_file = File(
        exists=True, desc=("Transform parameters to MNI space")
    )


class DARTELNorm2MNI(SPMCommand):
    """Use spm DARTEL to normalize data to MNI space

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=188

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> nm = spm.DARTELNorm2MNI()
    >>> nm.inputs.template_file = 'Template_6.nii'
    >>> nm.inputs.flowfield_files = ['u_rc1s1_Template.nii', 'u_rc1s3_Template.nii']
    >>> nm.inputs.apply_to_files = ['c1s1.nii', 'c1s3.nii']
    >>> nm.inputs.modulate = True
    >>> nm.run() # doctest: +SKIP

    """

    input_spec = DARTELNorm2MNIInputSpec
    output_spec = DARTELNorm2MNIOutputSpec
    _jobtype = "tools"
    _jobname = "dartel"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ["template_file"]:
            return np.array([val], dtype=object)
        elif opt in ["flowfield_files"]:
            return scans_for_fnames(val, keep4d=True)
        elif opt in ["apply_to_files"]:
            return scans_for_fnames(val, keep4d=True, separate_sessions=True)
        elif opt == "voxel_size":
            return list(val)
        elif opt == "bounding_box":
            return list(val)
        elif opt == "fwhm":
            # A scalar is expanded to three identical values.
            if isinstance(val, list):
                return val
            else:
                return [val, val, val]
        else:
            return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict the normalized file names from the current inputs.

        The file-name prefix is built from "w", an "m" in front when
        ``modulate`` is set, and an "s" in front when smoothing is active
        (``fwhm`` undefined, or any of its values greater than zero).
        """
        outputs = self._outputs().get()
        _, base, _ = split_filename(self.inputs.template_file)
        outputs["normalization_parameter_file"] = os.path.realpath(base + "_2mni.mat")
        outputs["normalized_files"] = []

        prefix = "w"
        if isdefined(self.inputs.modulate) and self.inputs.modulate:
            prefix = "m" + prefix

        # ``fwhm`` may be a float or a 3-element list. Comparing a list with
        # ``> 0`` raises TypeError on Python 3, so normalise to a list first.
        fwhm = self.inputs.fwhm
        smoothed = True
        if isdefined(fwhm):
            widths = fwhm if isinstance(fwhm, list) else [fwhm]
            smoothed = any(width > 0 for width in widths)
        if smoothed:
            prefix = "s" + prefix

        for filename in self.inputs.apply_to_files:
            _, base, ext = split_filename(filename)
            outputs["normalized_files"].append(os.path.realpath(f"{prefix}{base}{ext}"))

        return outputs


class CreateWarpedInputSpec(SPMCommandInputSpec):
    image_files = InputMultiPath(
        ImageFileSPM(exists=True),
        mandatory=True,
        desc="A list of files to be warped",
        field="crt_warped.images",
        copyfile=False,
    )
    flowfield_files = InputMultiPath(
        ImageFileSPM(exists=True),
        copyfile=False,
        desc="DARTEL flow fields u_rc1*",
        field="crt_warped.flowfields",
        mandatory=True,
    )
    iterations = traits.Range(
        low=0,
        high=9,
        desc=("The number of iterations: log2(number of time steps)"),
        field="crt_warped.K",
    )
    interp = traits.Range(
        low=0,
        high=7,
        field="crt_warped.interp",
        desc="degree of b-spline used for interpolation",
    )
    modulate = traits.Bool(field="crt_warped.jactransf", desc="Modulate images")


class CreateWarpedOutputSpec(TraitedSpec):
    # NOTE(review): ``desc`` is attached to the inner File trait rather than
    # to the List itself - confirm whether that is intended.
    warped_files = traits.List(File(exists=True, desc="final warped files"))


class CreateWarped(SPMCommand):
    """Apply a flow field estimated by DARTEL to create warped images

    http://www.fil.ion.ucl.ac.uk/spm/doc/manual.pdf#page=190

    Examples
    --------
    >>> import nipype.interfaces.spm as spm
    >>> create_warped = spm.CreateWarped()
    >>> create_warped.inputs.image_files = ['rc1s1.nii', 'rc1s2.nii']
    >>> create_warped.inputs.flowfield_files = ['u_rc1s1_Template.nii', 'u_rc1s2_Template.nii']
    >>> create_warped.run() # doctest: +SKIP

    """

    input_spec = CreateWarpedInputSpec
    output_spec = CreateWarpedOutputSpec
    _jobtype = "tools"
    _jobname = "dartel"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ["image_files"]:
            return scans_for_fnames(val, keep4d=True, separate_sessions=True)
        if opt in ["flowfield_files"]:
            return scans_for_fnames(val, keep4d=True)
        else:
            return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict warped file names: "mw" prefix when modulating, else "w"."""
        outputs = self._outputs().get()
        modulated = isdefined(self.inputs.modulate) and self.inputs.modulate
        prefix = "mw" if modulated else "w"

        outputs["warped_files"] = []
        for filename in self.inputs.image_files:
            _, base, ext = split_filename(filename)
            outputs["warped_files"].append(os.path.realpath(f"{prefix}{base}{ext}"))
        return outputs


class ApplyDeformationFieldInputSpec(SPMCommandInputSpec):
    in_files = InputMultiPath(ImageFileSPM(exists=True), mandatory=True, field="fnames")
    deformation_field = File(exists=True, mandatory=True, field="comp{1}.def")
    reference_volume = ImageFileSPM(
        exists=True, mandatory=True, field="comp{2}.id.space"
    )
    interp = traits.Range(
        low=0, high=7, field="interp", desc="degree of b-spline used for interpolation"
    )


class ApplyDeformationFieldOutputSpec(TraitedSpec):
    out_files = OutputMultiPath(File(exists=True))


class ApplyDeformations(SPMCommand):
    """Apply a deformation field to images via the SPM "util"/"defs" job."""

    input_spec = ApplyDeformationFieldInputSpec
    output_spec = ApplyDeformationFieldOutputSpec

    _jobtype = "util"
    _jobname = "defs"

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        # Single-file inputs are wrapped so scans_for_fnames gets a list.
        if opt in ["deformation_field", "reference_volume"]:
            val = [val]

        if opt in ["deformation_field"]:
            return scans_for_fnames(val, keep4d=True, separate_sessions=False)
        if opt in ["in_files", "reference_volume"]:
            return scans_for_fnames(val, keep4d=False, separate_sessions=False)
        else:
            return super()._format_arg(opt, spec, val)

    def _list_outputs(self):
        """Predict one "w"-prefixed output in the working directory per input."""
        outputs = self._outputs().get()
        outputs["out_files"] = []
        for filename in self.inputs.in_files:
            _, fname = os.path.split(filename)
            outputs["out_files"].append(os.path.realpath("w%s" % fname))
        return outputs


class VBMSegmentInputSpec(SPMCommandInputSpec):
    in_files = InputMultiPath(
        ImageFileSPM(exists=True),
        desc="A list of files to be segmented",
        field="estwrite.data",
        copyfile=False,
        mandatory=True,
    )

    tissues = ImageFileSPM(
        exists=True, field="estwrite.tpm", desc="tissue probability map"
    )
    # NOTE(review): this trait declares no ``field`` metadata, so it is
    # presumably not forwarded to the SPM job - confirm.
    gaussians_per_class = Tuple(
        (2, 2, 2, 3, 4, 2),
        *([traits.Int()] * 6),
        usedefault=True,
        desc="number of gaussians for each tissue class",
    )
    bias_regularization = traits.Enum(
        0.0001,
        (0, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10),
        field="estwrite.opts.biasreg",
        usedefault=True,
        desc="no(0) - extremely heavy (10)",
    )
    bias_fwhm = traits.Enum(
        60,
        (30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, "Inf"),
        field="estwrite.opts.biasfwhm",
        usedefault=True,
        desc="FWHM of Gaussian smoothness of bias",
    )
    sampling_distance = traits.Float(
        3,
        usedefault=True,
        field="estwrite.opts.samp",
        desc="Sampling distance on data for parameter estimation",
    )
    warping_regularization = traits.Float(
        4,
        usedefault=True,
        field="estwrite.opts.warpreg",
        desc="Controls balance between parameters and data",
    )

    # "high" is the default; VBMSegment._parse_inputs handles this trait
    # itself instead of letting it be formatted like the other inputs.
    spatial_normalization = traits.Enum("high", "low", usedefault=True)
    dartel_template = ImageFileSPM(
        exists=True, field="estwrite.extopts.dartelwarp.normhigh.darteltpm"
    )
    use_sanlm_denoising_filter = traits.Range(
        0,
        2,
        2,
        usedefault=True,
        field="estwrite.extopts.sanlm",
        desc="0=No denoising, 1=denoising,2=denoising multi-threaded",
    )
    mrf_weighting = traits.Float(0.15, usedefault=True, field="estwrite.extopts.mrf")
    cleanup_partitions = traits.Int(
        1,
        usedefault=True,
        field="estwrite.extopts.cleanup",
        desc="0=None,1=light,2=thorough",
    )
    display_results = traits.Bool(True, usedefault=True, field="estwrite.extopts.print")

    gm_native = traits.Bool(False, usedefault=True, field="estwrite.output.GM.native")
    gm_normalized = traits.Bool(
        False, usedefault=True, field="estwrite.output.GM.warped"
    )
    gm_modulated_normalized = traits.Range(
        0,
        2,
        2,
        usedefault=True,
        field="estwrite.output.GM.modulated",
        desc="0=none,1=affine+non-linear(SPM8 default),2=non-linear only",
    )
    gm_dartel = traits.Range(
        0,
        2,
        0,
        usedefault=True,
        field="estwrite.output.GM.dartel",
        desc="0=None,1=rigid(SPM8 default),2=affine",
    )

    wm_native = traits.Bool(False, usedefault=True, field="estwrite.output.WM.native")
    wm_normalized = traits.Bool(
        False, usedefault=True, field="estwrite.output.WM.warped"
    )
    wm_modulated_normalized = traits.Range(
        0,
        2,
        2,
        usedefault=True,
        field="estwrite.output.WM.modulated",
        desc="0=none,1=affine+non-linear(SPM8 default),2=non-linear only",
    )
    wm_dartel = traits.Range(
        0,
        2,
        0,
        usedefault=True,
        field="estwrite.output.WM.dartel",
        desc="0=None,1=rigid(SPM8 default),2=affine",
    )

    csf_native = traits.Bool(False, usedefault=True, field="estwrite.output.CSF.native")
    csf_normalized = traits.Bool(
        False, usedefault=True, field="estwrite.output.CSF.warped"
    )
    csf_modulated_normalized = traits.Range(
        0,
        2,
        2,
        usedefault=True,
        field="estwrite.output.CSF.modulated",
        desc="0=none,1=affine+non-linear(SPM8 default),2=non-linear only",
    )
    csf_dartel = traits.Range(
        0,
        2,
        0,
        usedefault=True,
        field="estwrite.output.CSF.dartel",
        desc="0=None,1=rigid(SPM8 default),2=affine",
    )

    bias_corrected_native = traits.Bool(
        False, usedefault=True, field="estwrite.output.bias.native"
    )
    bias_corrected_normalized = traits.Bool(
        True, usedefault=True, field="estwrite.output.bias.warped"
    )
    bias_corrected_affine = traits.Bool(
        False, usedefault=True, field="estwrite.output.bias.affine"
    )

    pve_label_native = traits.Bool(
        False, usedefault=True, field="estwrite.output.label.native"
    )
    pve_label_normalized = traits.Bool(
        False, usedefault=True, field="estwrite.output.label.warped"
    )
    pve_label_dartel = traits.Range(
        0,
        2,
        0,
        usedefault=True,
        field="estwrite.output.label.dartel",
        desc="0=None,1=rigid(SPM8 default),2=affine",
    )

    jacobian_determinant = traits.Bool(
        False, usedefault=True, field="estwrite.jacobian.warped"
    )

    deformation_field = Tuple(
        (0, 0),
        traits.Bool,
        traits.Bool,
        usedefault=True,
        field="estwrite.output.warps",
        desc="forward and inverse field",
    )


class VBMSegmentOuputSpec(TraitedSpec):
    native_class_images = traits.List(
        traits.List(File(exists=True)), desc="native space probability maps"
    )
    dartel_input_images = traits.List(
        traits.List(File(exists=True)), desc="dartel imported class images"
    )
    normalized_class_images = traits.List(
        traits.List(File(exists=True)), desc="normalized class images"
    )
    modulated_class_images = traits.List(
        traits.List(File(exists=True)), desc=("modulated+normalized class images")
    )
    transformation_mat = OutputMultiPath(
        File(exists=True), desc="Normalization transformation"
    )

    bias_corrected_images = OutputMultiPath(
        File(exists=True), desc="bias corrected images"
    )
    normalized_bias_corrected_images = OutputMultiPath(
        File(exists=True), desc="bias corrected images"
    )

    pve_label_native_images = OutputMultiPath(File(exists=True))
    pve_label_normalized_images = OutputMultiPath(File(exists=True))
    pve_label_registered_images = OutputMultiPath(File(exists=True))

    forward_deformation_field = OutputMultiPath(File(exists=True))
    inverse_deformation_field = OutputMultiPath(File(exists=True))

    jacobian_determinant_images = OutputMultiPath(File(exists=True))


class VBMSegment(SPMCommand):
    """Use VBM8 toolbox to separate structural images into different
    tissue classes.

    Example
    -------
    >>> import nipype.interfaces.spm as spm
    >>> seg = spm.VBMSegment()
    >>> seg.inputs.tissues = 'TPM.nii'
    >>> seg.inputs.dartel_template = 'Template_1_IXI550_MNI152.nii'
    >>> seg.inputs.bias_corrected_native = True
    >>> seg.inputs.gm_native = True
    >>> seg.inputs.wm_native = True
    >>> seg.inputs.csf_native = True
    >>> seg.inputs.pve_label_native = True
    >>> seg.inputs.deformation_field = (True, False)
    >>> seg.run() # doctest: +SKIP
    """

    input_spec = VBMSegmentInputSpec
    output_spec = VBMSegmentOuputSpec

    _jobtype = "tools"
    _jobname = "vbm8"

    def _list_outputs(self):
        """Predict the output file paths from the current inputs.

        File names in normalized space carry an extra "r" when
        ``spatial_normalization`` is "high"; the Jacobian determinant image
        is only listed in that mode.
        """
        outputs = self._outputs().get()

        # ``spatial_normalization`` is the string "high" or "low". Testing
        # its truthiness (as before) was always True, so the "low" setting
        # was given DARTEL-style names as well.
        # NOTE(review): assumes "high" is the DARTEL mode, as suggested by
        # the ``dartelwarp.normhigh`` field of ``dartel_template`` - confirm
        # the resulting names against the VBM8 toolbox.
        do_dartel = self.inputs.spatial_normalization == "high"
        dartel_px = "r" if do_dartel else ""

        outputs["native_class_images"] = [[], [], []]
        outputs["dartel_input_images"] = [[], [], []]
        outputs["normalized_class_images"] = [[], [], []]
        outputs["modulated_class_images"] = [[], [], []]

        outputs["transformation_mat"] = []

        outputs["bias_corrected_images"] = []
        outputs["normalized_bias_corrected_images"] = []

        outputs["inverse_deformation_field"] = []
        outputs["forward_deformation_field"] = []
        outputs["jacobian_determinant_images"] = []

        outputs["pve_label_native_images"] = []
        outputs["pve_label_normalized_images"] = []
        outputs["pve_label_registered_images"] = []

        for filename in self.inputs.in_files:
            pth, base, _ = split_filename(filename)

            outputs["transformation_mat"].append(
                os.path.join(pth, "%s_seg8.mat" % base)
            )

            for i, tis in enumerate(["gm", "wm", "csf"]):
                # native space
                if getattr(self.inputs, "%s_native" % tis):
                    outputs["native_class_images"][i].append(
                        os.path.join(pth, "p%d%s.nii" % (i + 1, base))
                    )
                if getattr(self.inputs, "%s_dartel" % tis) == 1:
                    outputs["dartel_input_images"][i].append(
                        os.path.join(pth, "rp%d%s.nii" % (i + 1, base))
                    )
                elif getattr(self.inputs, "%s_dartel" % tis) == 2:
                    outputs["dartel_input_images"][i].append(
                        os.path.join(pth, "rp%d%s_affine.nii" % (i + 1, base))
                    )

                # normalized space
                if getattr(self.inputs, "%s_normalized" % tis):
                    outputs["normalized_class_images"][i].append(
                        os.path.join(pth, "w%sp%d%s.nii" % (dartel_px, i + 1, base))
                    )

                if getattr(self.inputs, "%s_modulated_normalized" % tis) == 1:
                    outputs["modulated_class_images"][i].append(
                        os.path.join(pth, "mw%sp%d%s.nii" % (dartel_px, i + 1, base))
                    )
                elif getattr(self.inputs, "%s_modulated_normalized" % tis) == 2:
                    # NOTE(review): the "m0w" images are appended to
                    # ``normalized_class_images`` rather than
                    # ``modulated_class_images``. This looks like a
                    # copy-paste slip, but it is kept because existing
                    # callers may rely on it - confirm before changing.
                    outputs["normalized_class_images"][i].append(
                        os.path.join(pth, "m0w%sp%d%s.nii" % (dartel_px, i + 1, base))
                    )

            if self.inputs.pve_label_native:
                outputs["pve_label_native_images"].append(
                    os.path.join(pth, "p0%s.nii" % (base))
                )
            if self.inputs.pve_label_normalized:
                outputs["pve_label_normalized_images"].append(
                    os.path.join(pth, f"w{dartel_px}p0{base}.nii")
                )
            if self.inputs.pve_label_dartel == 1:
                outputs["pve_label_registered_images"].append(
                    os.path.join(pth, "rp0%s.nii" % (base))
                )
            elif self.inputs.pve_label_dartel == 2:
                outputs["pve_label_registered_images"].append(
                    os.path.join(pth, "rp0%s_affine.nii" % (base))
                )

            if self.inputs.bias_corrected_native:
                outputs["bias_corrected_images"].append(
                    os.path.join(pth, "m%s.nii" % (base))
                )
            if self.inputs.bias_corrected_normalized:
                outputs["normalized_bias_corrected_images"].append(
                    os.path.join(pth, f"wm{dartel_px}{base}.nii")
                )

            # deformation_field is (forward, inverse).
            if self.inputs.deformation_field[0]:
                outputs["forward_deformation_field"].append(
                    os.path.join(pth, f"y_{dartel_px}{base}.nii")
                )
            if self.inputs.deformation_field[1]:
                outputs["inverse_deformation_field"].append(
                    os.path.join(pth, f"iy_{dartel_px}{base}.nii")
                )

            if self.inputs.jacobian_determinant and do_dartel:
                outputs["jacobian_determinant_images"].append(
                    os.path.join(pth, "jac_wrp1%s.nii" % (base))
                )
        return outputs

    def _format_arg(self, opt, spec, val):
        """Convert input to appropriate format for spm"""
        if opt in ["in_files"]:
            return scans_for_fnames(val, keep4d=True)
        elif opt in ["spatial_normalization"]:
            # Any value other than "low" falls through and returns None;
            # _parse_inputs always skips this option anyway.
            if val == "low":
                return {"normlow": []}
        elif opt in ["dartel_template"]:
            return np.array([val], dtype=object)
        elif opt in ["deformation_field"]:
            return super()._format_arg(opt, spec, [int(val[0]), int(val[1])])
        else:
            return super()._format_arg(opt, spec, val)

    def _parse_inputs(self):
        """Build the SPM job inputs, handling ``spatial_normalization`` here.

        For "low" the DARTEL template is skipped as well and a ``normlow``
        entry is written into ``estwrite.extopts.dartelwarp``.
        """
        if self.inputs.spatial_normalization == "low":
            einputs = super()._parse_inputs(
                skip=("spatial_normalization", "dartel_template")
            )
            einputs[0]["estwrite"]["extopts"]["dartelwarp"] = {"normlow": 1}
            return einputs
        else:
            # A real one-element tuple: ``("spatial_normalization")`` without
            # the trailing comma is just a parenthesised string.
            return super()._parse_inputs(skip=("spatial_normalization",))