Source code for dipy.workflows.io

import importlib
from inspect import getmembers, isfunction
import logging
import os
from os.path import join as pjoin
import sys
import warnings

import numpy as np
import trx.trx_file_memmap as tmm

from dipy.core.sphere import Sphere
from dipy.data import get_sphere
from dipy.io.image import load_nifti, save_nifti
from dipy.io.peaks import (
    load_pam,
    niftis_to_pam,
    pam_to_niftis,
    tensor_to_pam,
)
from dipy.io.streamline import load_tractogram, save_tractogram
from dipy.reconst.shm import convert_sh_descoteaux_tournier
from dipy.reconst.utils import convert_tensors
from dipy.tracking.streamlinespeed import length
from dipy.utils.tractogram import concatenate_tractogram
from dipy.workflows.workflow import Workflow


class IoInfoFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "io_info"

    def run(
        self,
        input_files,
        b0_threshold=50,
        bvecs_tol=0.01,
        bshell_thr=100,
        reference=None,
    ):
        """Provides useful information about different files used in
        medical imaging. Any number of input files can be provided. The
        program identifies the type of file by its extension.

        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1, bvals or bvecs files.
        b0_threshold : float, optional
            Threshold used to find b0 volumes.
        bvecs_tol : float, optional
            Threshold used to check that norm(bvec) = 1 +/- bvecs_tol,
            i.e. that the b-vectors are unit vectors.
        bshell_thr : float, optional
            Threshold for distinguishing b-values in different shells.
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files
            (.nii or .nii.gz).

        """
        np.set_printoptions(3, suppress=True)

        io_it = self.get_io_iterator()

        for input_path in io_it:
            mult_ = len(input_path)
            logging.info(f"-----------{mult_ * '-'}")
            logging.info(f"Looking at {input_path}")
            logging.info(f"-----------{mult_ * '-'}")

            ipath_lower = input_path.lower()
            extension = os.path.splitext(ipath_lower)[1]

            if ipath_lower.endswith(".nii") or ipath_lower.endswith(".nii.gz"):
                data, affine, img, vox_sz, affcodes = load_nifti(
                    input_path,
                    return_img=True,
                    return_voxsize=True,
                    return_coords=True,
                )
                logging.info(f"Data size {data.shape}")
                logging.info(f"Data type {data.dtype}")

                if data.ndim == 3:
                    logging.info(
                        f"Data min {data.min()} max {data.max()} avg {data.mean()}"
                    )
                    logging.info(
                        f"2nd percentile {np.percentile(data, 2)} "
                        f"98th percentile {np.percentile(data, 98)}"
                    )

                if data.ndim == 4:
                    logging.info(
                        f"Data min {data[..., 0].min()} "
                        f"max {data[..., 0].max()} "
                        f"avg {data[..., 0].mean()} of vol 0"
                    )
                    msg = (
                        f"2nd percentile {np.percentile(data[..., 0], 2)} "
                        f"98th percentile {np.percentile(data[..., 0], 98)} "
                        f"of vol 0"
                    )
                    logging.info(msg)

                logging.info(f"Native coordinate system {''.join(affcodes)}")
                logging.info(f"Affine Native to RAS matrix \n{affine}")
                logging.info(f"Voxel size {np.array(vox_sz)}")

                if np.sum(np.abs(np.diff(vox_sz))) > 0.1:
                    msg = "Voxel size is not isotropic. Please reslice.\n"
                    logging.warning(msg, stacklevel=2)

            if os.path.basename(input_path).lower().find("bval") > -1:
                bvals = np.loadtxt(input_path)
                logging.info(f"b-values \n{bvals}")
                logging.info(f"Total number of b-values {len(bvals)}")
                shells = np.sum(np.diff(np.sort(bvals)) > bshell_thr)
                logging.info(f"Number of gradient shells {shells}")
                logging.info(
                    f"Number of b0s {np.sum(bvals <= b0_threshold)} "
                    f"(b0_thr {b0_threshold})\n"
                )

            if os.path.basename(input_path).lower().find("bvec") > -1:
                bvecs = np.loadtxt(input_path)
                logging.info(f"Bvectors shape on disk is {bvecs.shape}")
                rows, cols = bvecs.shape
                if rows < cols:
                    bvecs = bvecs.T
                logging.info(f"Bvectors are \n{bvecs}")
                norms = np.array([np.linalg.norm(bvec) for bvec in bvecs])
                res = np.where((norms <= 1 + bvecs_tol) & (norms >= 1 - bvecs_tol))
                ncl1 = np.sum(norms < 1 - bvecs_tol)
                logging.info(f"Total number of unit bvectors {len(res[0])}")
                logging.info(f"Total number of non-unit bvectors {ncl1}\n")

            if extension in [".trk", ".tck", ".trx", ".vtk", ".vtp", ".fib", ".dpy"]:
                sft = None
                if extension in [".trk", ".trx"]:
                    sft = load_tractogram(input_path, "same", bbox_valid_check=False)
                else:
                    sft = load_tractogram(input_path, reference, bbox_valid_check=False)

                lengths_mm = list(length(sft.streamlines))

                sft.to_voxmm()

                lengths, steps = [], []
                for streamline in sft.streamlines:
                    lengths += [len(streamline)]
                    steps += [np.sqrt(np.sum(np.diff(streamline, axis=0) ** 2, axis=1))]
                steps = np.hstack(steps)

                logging.info(f"Number of streamlines: {len(sft)}")
                logging.info(f"min_length_mm: {float(np.min(lengths_mm))}")
                logging.info(f"mean_length_mm: {float(np.mean(lengths_mm))}")
                logging.info(f"max_length_mm: {float(np.max(lengths_mm))}")
                logging.info(f"std_length_mm: {float(np.std(lengths_mm))}")
                logging.info(f"min_length_nb_points: {float(np.min(lengths))}")
                logging.info(f"mean_length_nb_points: {float(np.mean(lengths))}")
                logging.info(f"max_length_nb_points: {float(np.max(lengths))}")
                logging.info(f"std_length_nb_points: {float(np.std(lengths))}")
                logging.info(f"min_step_size: {float(np.min(steps))}")
                logging.info(f"mean_step_size: {float(np.mean(steps))}")
                logging.info(f"max_step_size: {float(np.max(steps))}")
                logging.info(f"std_step_size: {float(np.std(steps))}")
                logging.info(f"data_per_point_keys: {list(sft.data_per_point.keys())}")
                logging.info(
                    f"data_per_streamline_keys: "
                    f"{list(sft.data_per_streamline.keys())}"
                )

        np.set_printoptions()
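
# Illustrative usage sketch (not part of the upstream module): running the
# flow programmatically instead of through the command line. The file names
# below are hypothetical; any mix of Nifti1, bvals and bvecs paths works
# the same way.
def _example_io_info():
    IoInfoFlow().run(["dwi.nii.gz", "dwi.bval", "dwi.bvec"], b0_threshold=50)
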
class FetchFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "fetch"

    @staticmethod
    def get_fetcher_datanames():
        """Gets available dataset and function names.

        Returns
        -------
        available_data: dict
            Available dataset and function names.

        """
        fetcher_module = FetchFlow.load_module("dipy.data.fetcher")

        available_data = dict(
            {
                (name.replace("fetch_", ""), func)
                for name, func in getmembers(fetcher_module, isfunction)
                if name.lower().startswith("fetch_")
                and func is not fetcher_module.fetch_data
                if name.lower() not in ["fetch_hbn", "fetch_hcp"]
            }
        )

        return available_data
    @staticmethod
    def load_module(module_path):
        """Load / reload an external module.

        Parameters
        ----------
        module_path: string
            the path to the module relative to the main script

        Returns
        -------
        module: module object

        """
        if module_path in sys.modules:
            return importlib.reload(sys.modules[module_path])
        else:
            return importlib.import_module(module_path)
    def run(self, data_names, out_dir=""):
        """Download files to folder and check their md5 checksums.

        To see all available datasets, please type "list" in data_names.

        Parameters
        ----------
        data_names : variable string
            Any number of available dataset names. Use "list" to display
            all of them.
        out_dir : string, optional
            Output directory. (default current directory)

        """
        if out_dir:
            dipy_home = os.environ.get("DIPY_HOME", None)
            os.environ["DIPY_HOME"] = out_dir

        available_data = FetchFlow.get_fetcher_datanames()

        data_names = [name.lower() for name in data_names]

        if "all" in data_names:
            for name, fetcher_function in available_data.items():
                logging.info("------------------------------------------")
                logging.info(f"Fetching at {name}")
                logging.info("------------------------------------------")
                fetcher_function()

        elif "list" in data_names:
            logging.info(
                "Please, select between the following data names: "
                f"{', '.join(available_data.keys())}"
            )

        else:
            skipped_names = []
            for data_name in data_names:
                if data_name not in available_data.keys():
                    skipped_names.append(data_name)
                    continue

                logging.info("------------------------------------------")
                logging.info(f"Fetching at {data_name}")
                logging.info("------------------------------------------")
                available_data[data_name]()

            nb_success = len(data_names) - len(skipped_names)
            print("\n")
            logging.info(f"Fetched {nb_success} / {len(data_names)} Files ")
            if skipped_names:
                logging.warning(f"Skipped data name(s): {' '.join(skipped_names)}")
                logging.warning(
                    "Please, select between the following data names: "
                    f"{', '.join(available_data.keys())}"
                )

        if out_dir:
            if dipy_home:
                os.environ["DIPY_HOME"] = dipy_home
            else:
                os.environ.pop("DIPY_HOME", None)

            # We load the module again so that if we run another one of these
            # in the same process, we don't have the env variable pointing
            # to the wrong place.
            self.load_module("dipy.data.fetcher")
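
# Illustrative usage sketch: "list" logs every available dataset name, and
# out_dir temporarily overrides DIPY_HOME for the duration of the fetch.
# "stanford_hardi" is one of DIPY's bundled fetchers.
def _example_fetch():
    FetchFlow().run(["list"])
    FetchFlow().run(["stanford_hardi"], out_dir="datasets")
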
class SplitFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "split"

    def run(self, input_files, vol_idx=0, out_dir="", out_split="split.nii.gz"):
        """Splits the input 4D file and extracts the required 3D volume.

        Parameters
        ----------
        input_files : variable string
            Any number of Nifti1 files.
        vol_idx : int, optional
            Index of the 3D volume to extract.
        out_dir : string, optional
            Output directory. (default current directory)
        out_split : string, optional
            Name of the resulting split volume.

        """
        io_it = self.get_io_iterator()
        for fpath, osplit in io_it:
            logging.info(f"Splitting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)

            if vol_idx == 0:
                logging.info("Splitting and extracting 1st b0")

            split_vol = data[..., vol_idx]
            save_nifti(osplit, split_vol, affine, hdr=image.header)

            logging.info(f"Split volume saved as {osplit}")
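
# Illustrative usage sketch: extracting volume 0 (typically a b0) from a
# hypothetical 4D file "dwi.nii.gz"; the output name comes from out_split.
def _example_split():
    SplitFlow().run(["dwi.nii.gz"], vol_idx=0, out_split="b0.nii.gz")
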
class ConcatenateTractogramFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "concatracks"

    def run(
        self,
        tractogram_files,
        reference=None,
        delete_dpv=False,
        delete_dps=False,
        delete_groups=False,
        check_space_attributes=True,
        preallocation=False,
        out_dir="",
        out_extension="trx",
        out_tractogram="concatenated_tractogram",
    ):
        """Concatenate multiple tractograms into one.

        Parameters
        ----------
        tractogram_files : variable string
            The stateful tractogram filenames to concatenate.
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files
            (.nii or .nii.gz).
        delete_dpv : bool, optional
            Delete dpv keys that do not exist in all the provided TrxFiles.
        delete_dps : bool, optional
            Delete dps keys that do not exist in all the provided TrxFiles.
        delete_groups : bool, optional
            Delete all the groups that currently exist in the TrxFiles.
        check_space_attributes : bool, optional
            Verify that dimensions and size of data are similar between all
            the TrxFiles.
        preallocation : bool, optional
            Preallocated TrxFile has already been generated and is the first
            element in trx_list (Note: delete_groups must be set to True as
            well).
        out_dir : string, optional
            Output directory. (default current directory)
        out_extension : string, optional
            Extension of the resulting tractogram.
        out_tractogram : string, optional
            Name of the resulting tractogram.

        """
        io_it = self.get_io_iterator()

        trx_list = []
        has_group = False
        for fpath, _, _ in io_it:
            if fpath.lower().endswith(".trx") or fpath.lower().endswith(".trk"):
                reference = "same"

            if not reference:
                raise ValueError(
                    "No reference provided. It is needed for tck, "
                    "fib, dpy or vtk files."
                )

            tractogram_obj = load_tractogram(fpath, reference, bbox_valid_check=False)

            if not isinstance(tractogram_obj, tmm.TrxFile):
                tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj)
            elif len(tractogram_obj.groups):
                has_group = True
            trx_list.append(tractogram_obj)

        trx = concatenate_tractogram(
            trx_list,
            delete_dpv=delete_dpv,
            delete_dps=delete_dps,
            delete_groups=delete_groups or not has_group,
            check_space_attributes=check_space_attributes,
            preallocation=preallocation,
        )

        valid_extensions = ["trk", "trx", "tck", "fib", "dpy", "vtk"]
        if out_extension.lower() not in valid_extensions:
            raise ValueError(
                f"Invalid extension. Valid extensions are: {valid_extensions}"
            )

        out_fpath = os.path.join(out_dir, f"{out_tractogram}.{out_extension}")
        save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False)
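
# Illustrative usage sketch: trk/trx inputs carry their own reference, so
# none needs to be passed here; the two input filenames are hypothetical.
def _example_concatenate():
    ConcatenateTractogramFlow().run(
        ["bundle_1.trk", "bundle_2.trk"],
        out_extension="trx",
        out_tractogram="whole_brain",
    )
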
class ConvertSHFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_dipy_mrtrix"

    def run(
        self,
        input_files,
        out_dir="",
        out_file="sh_convert_dipy_mrtrix_out.nii.gz",
    ):
        """Converts SH basis representation between DIPY and MRtrix3 formats.

        Because this conversion is equal to its own inverse, it can be used
        to convert in either direction: DIPY to MRtrix3 or vice versa.

        Parameters
        ----------
        input_files : string
            Path to the input files. This path may contain wildcards to
            process multiple inputs at once.
        out_dir : string, optional
            Where the resulting file will be saved. (default '')
        out_file : string, optional
            Name of the result file to be saved.
            (default 'sh_convert_dipy_mrtrix_out.nii.gz')

        """
        io_it = self.get_io_iterator()

        for in_file, out_file in io_it:
            data, affine, image = load_nifti(in_file, return_img=True)
            data = convert_sh_descoteaux_tournier(data)
            save_nifti(out_file, data, affine, hdr=image.header)
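
# Illustrative usage sketch: because the basis conversion is an involution,
# the same call maps DIPY-convention SH volumes to MRtrix3 convention and
# back. "sh_dipy.nii.gz" is a hypothetical input path.
def _example_convert_sh():
    ConvertSHFlow().run("sh_dipy.nii.gz", out_file="sh_mrtrix.nii.gz")
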
class ConvertTensorsFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_tensors"

    def run(
        self,
        tensor_files,
        from_format="mrtrix",
        to_format="dipy",
        out_dir=".",
        out_tensor="converted_tensor",
    ):
        """Converts tensor representation between different formats.

        Parameters
        ----------
        tensor_files : variable string
            Any number of tensor files.
        from_format : string, optional
            Format of the input tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        to_format : string, optional
            Format of the output tensor files. Valid options are 'dipy',
            'mrtrix', 'ants', 'fsl'.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tensor : string, optional
            Name of the resulting tensor file.

        """
        io_it = self.get_io_iterator()
        for fpath, otensor in io_it:
            logging.info(f"Converting {fpath}")
            data, affine, image = load_nifti(fpath, return_img=True)
            data = convert_tensors(data, from_format, to_format)
            save_nifti(otensor, data, affine, hdr=image.header)
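
# Illustrative usage sketch: converting a hypothetical FSL-layout tensor
# volume to DIPY's layout; from_format/to_format accept any of the four
# formats listed in the docstring.
def _example_convert_tensors():
    ConvertTensorsFlow().run(
        "tensors_fsl.nii.gz", from_format="fsl", to_format="dipy"
    )
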
class ConvertTractogramFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "convert_tractogram"

    def run(
        self,
        input_files,
        reference=None,
        pos_dtype="float32",
        offsets_dtype="uint32",
        out_dir="",
        out_tractogram="converted_tractogram.trk",
    ):
        """Converts tractograms between different formats.

        Parameters
        ----------
        input_files : variable string
            Any number of tractogram files.
        reference : string, optional
            Reference anatomy for tck/vtk/fib/dpy files
            (.nii or .nii.gz).
        pos_dtype : string, optional
            Data type of the tractogram points, used for vtk files.
        offsets_dtype : string, optional
            Data type of the tractogram offsets, used for vtk files.
        out_dir : string, optional
            Output directory. (default current directory)
        out_tractogram : string, optional
            Name of the resulting tractogram.

        """
        io_it = self.get_io_iterator()

        for fpath, otracks in io_it:
            in_extension = fpath.lower().split(".")[-1]
            out_extension = otracks.lower().split(".")[-1]

            if in_extension == out_extension:
                warnings.warn(
                    "Input and output are the same file format. Skipping...",
                    stacklevel=2,
                )
                continue

            if not reference and in_extension in ["trx", "trk"]:
                reference = "same"

            if not reference and in_extension not in ["trx", "trk"]:
                raise ValueError(
                    "No reference provided. It is needed for tck, "
                    "fib, dpy or vtk files."
                )

            sft = load_tractogram(fpath, reference, bbox_valid_check=False)

            if out_extension != "trx":
                if out_extension == "vtk":
                    if sft.streamlines._data.dtype.name != pos_dtype:
                        sft.streamlines._data = sft.streamlines._data.astype(pos_dtype)
                    if offsets_dtype == "uint64" or offsets_dtype == "uint32":
                        offsets_dtype = offsets_dtype[1:]
                    if sft.streamlines._offsets.dtype.name != offsets_dtype:
                        sft.streamlines._offsets = sft.streamlines._offsets.astype(
                            offsets_dtype
                        )
                save_tractogram(sft, otracks, bbox_valid_check=False)
            else:
                trx = tmm.TrxFile.from_sft(sft)
                if trx.streamlines._data.dtype.name != pos_dtype:
                    trx.streamlines._data = trx.streamlines._data.astype(pos_dtype)
                if trx.streamlines._offsets.dtype.name != offsets_dtype:
                    trx.streamlines._offsets = trx.streamlines._offsets.astype(
                        offsets_dtype
                    )
                tmm.save(trx, otracks)
                trx.close()
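
# Illustrative usage sketch: tck files need a reference anatomy, while the
# trk output embeds it. Both file names are hypothetical.
def _example_convert_tractogram():
    ConvertTractogramFlow().run(
        "tracks.tck", reference="t1.nii.gz", out_tractogram="tracks.trk"
    )
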
class NiftisToPamFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "niftis_to_pam"

    def run(
        self,
        peaks_dir_files,
        peaks_values_files,
        peaks_indices_files,
        shm_files=None,
        gfa_files=None,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple nifti files to a single pam5 file.

        Parameters
        ----------
        peaks_dir_files : string
            Path to the input peaks directions volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_values_files : string
            Path to the input peaks values volume. This path may contain
            wildcards to process multiple inputs at once.
        peaks_indices_files : string
            Path to the input peaks indices volume. This path may contain
            wildcards to process multiple inputs at once.
        shm_files : string, optional
            Path to the input spherical harmonics volume. This path may
            contain wildcards to process multiple inputs at once.
        gfa_files : string, optional
            Path to the input generalized FA volume. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not
            defined, the default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by the
            sphere_files option. Possible options: ['symmetric362',
            'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100',
            'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"
        for fpeak_dirs, fpeak_values, fpeak_indices, opam in io_it:
            logging.info("Converting nifti files to pam5")
            peak_dirs, affine = load_nifti(fpeak_dirs)
            peak_values, _ = load_nifti(fpeak_values)
            peak_indices, _ = load_nifti(fpeak_indices)

            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)

            niftis_to_pam(
                affine=affine,
                peak_dirs=peak_dirs,
                sphere=sphere,
                peak_values=peak_values,
                peak_indices=peak_indices,
                pam_file=opam,
            )
            logging.info(msg.replace("pam5", opam))
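
# Illustrative usage sketch: bundling hypothetical peaks volumes (such as
# those written by PamToNiftisFlow below) back into a single pam5 file.
def _example_niftis_to_pam():
    NiftisToPamFlow().run(
        "peaks_dirs.nii.gz",
        "peaks_values.nii.gz",
        "peaks_indices.nii.gz",
        out_pam="peaks.pam5",
    )
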
class TensorToPamFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "tensor_to_pam"

    def run(
        self,
        evals_files,
        evecs_files,
        sphere_files=None,
        default_sphere_name="repulsion724",
        out_dir="",
        out_pam="peaks.pam5",
    ):
        """Convert multiple tensor files (evals, evecs) to pam5 files.

        Parameters
        ----------
        evals_files : string
            Path to the input eigenvalue volumes. This path may contain
            wildcards to process multiple inputs at once.
        evecs_files : string
            Path to the input eigenvector volumes. This path may contain
            wildcards to process multiple inputs at once.
        sphere_files : string, optional
            Path to the input sphere vertices. This path may contain
            wildcards to process multiple inputs at once. If it is not
            defined, the default_sphere option will be used.
        default_sphere_name : string, optional
            Specify default sphere to use for spherical harmonics
            representation. This option can be superseded by the
            sphere_files option. Possible options: ['symmetric362',
            'symmetric642', 'symmetric724', 'repulsion724', 'repulsion100',
            'repulsion200'].
        out_dir : string, optional
            Output directory (default input file directory).
        out_pam : string, optional
            Name of the peaks volume to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"pam5 files saved in {out_dir or 'current directory'}"
        for fevals, fevecs, opam in io_it:
            logging.info("Converting tensor files to pam5...")
            evals, affine = load_nifti(fevals)
            evecs, _ = load_nifti(fevecs)

            if sphere_files:
                xyz = np.loadtxt(sphere_files)
                sphere = Sphere(xyz=xyz)
            else:
                sphere = get_sphere(name=default_sphere_name)

            tensor_to_pam(evals, evecs, affine, sphere=sphere, pam_file=opam)
            logging.info(msg.replace("pam5", opam))
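
# Illustrative usage sketch: eigenvalue/eigenvector volumes such as those
# saved by a DTI fit; both input file names are hypothetical.
def _example_tensor_to_pam():
    TensorToPamFlow().run("evals.nii.gz", "evecs.nii.gz", out_pam="dti_peaks.pam5")
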
class PamToNiftisFlow(Workflow):
    @classmethod
    def get_short_name(cls):
        return "pam_to_niftis"

    def run(
        self,
        pam_files,
        out_dir="",
        out_peaks_dir="peaks_dirs.nii.gz",
        out_peaks_values="peaks_values.nii.gz",
        out_peaks_indices="peaks_indices.nii.gz",
        out_shm="shm.nii.gz",
        out_gfa="gfa.nii.gz",
        out_sphere="sphere.txt",
    ):
        """Convert pam5 files to multiple nifti files.

        Parameters
        ----------
        pam_files : string
            Path to the input peaks volumes. This path may contain wildcards
            to process multiple inputs at once.
        out_dir : string, optional
            Output directory (default input file directory).
        out_peaks_dir : string, optional
            Name of the peaks directions volume to be saved.
        out_peaks_values : string, optional
            Name of the peaks values volume to be saved.
        out_peaks_indices : string, optional
            Name of the peaks indices volume to be saved.
        out_shm : string, optional
            Name of the spherical harmonics volume to be saved.
        out_gfa : string, optional
            Name of the generalized FA volume to be saved.
        out_sphere : string, optional
            Name of the sphere vertices file to be saved.

        """
        io_it = self.get_io_iterator()
        msg = f"Nifti files saved in {out_dir or 'current directory'}"
        for (
            ipam,
            opeaks_dir,
            opeaks_values,
            opeaks_indices,
            oshm,
            ogfa,
            osphere,
        ) in io_it:
            logging.info("Converting %s file to niftis...", ipam)
            pam = load_pam(ipam)
            pam_to_niftis(
                pam,
                fname_peaks_dir=opeaks_dir,
                fname_shm=oshm,
                fname_peaks_values=opeaks_values,
                fname_peaks_indices=opeaks_indices,
                fname_sphere=osphere,
                fname_gfa=ogfa,
                fname_b=pjoin(out_dir, "B.nii.gz"),
                fname_qa=pjoin(out_dir, "qa.nii.gz"),
            )
            logging.info(msg)
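
# Illustrative usage sketch: unpacking a hypothetical pam5 file into the
# individual nifti volumes named by the out_* arguments above.
def _example_pam_to_niftis():
    PamToNiftisFlow().run("peaks.pam5", out_dir="unpacked")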