Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/scilus/scilpy into lesion…
Browse files Browse the repository at this point in the history
…s_analysis
  • Loading branch information
frheault committed Jul 8, 2021
2 parents 4846a99 + f4ea3d3 commit d7cda04
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 122 deletions.
16 changes: 7 additions & 9 deletions scilpy/segment/voting_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,13 +288,13 @@ def __call__(self, input_tractogram_path, nbr_processes=1, seeds=None):
slr_transform_type, seed])

tmp_dir, tmp_memmap_filenames = streamlines_to_memmap(wb_streamlines)
del wb_streamlines
comb_param_cluster = product(self.tractogram_clustering_thr, seeds)

# Clustring is now parallelize
pool = multiprocessing.Pool(nbr_processes)
all_rbx_dict = pool.map(single_clusterize_and_rbx_init,
zip(repeat(wb_streamlines),
repeat(tmp_memmap_filenames),
zip(repeat(tmp_memmap_filenames),
comb_param_cluster,
repeat(self.nb_points)))
pool.close()
Expand Down Expand Up @@ -363,8 +363,6 @@ def single_clusterize_and_rbx_init(args):
Parameters
----------
wb_streamlines : list or ArraySequence
All streamlines of the tractogram to segment.
tmp_memmap_filename: tuple (3)
Temporary filename for the data, offsets and lengths.
Expand All @@ -381,11 +379,11 @@ def single_clusterize_and_rbx_init(args):
rbx : dict
Initialisation of the recobundles class using specific parameters.
"""
wb_streamlines = args[0]
tmp_memmap_filename = args[1]
clustering_thr = args[2][0]
seed = args[2][1]
nb_points = args[3]
tmp_memmap_filename = args[0]
wb_streamlines = reconstruct_streamlines_from_memmap(tmp_memmap_filename)
clustering_thr = args[1][0]
seed = args[1][1]
nb_points = args[2]

rbx = {}
base_thresholds = [45, 35, 25]
Expand Down
43 changes: 43 additions & 0 deletions scilpy/utils/transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-

from dipy.io.stateful_tractogram import StatefulTractogram
import numpy as np


def get_axis_flip_vector(flip_axes):
flip_vector = np.ones(3)
if 'x' in flip_axes:
flip_vector[0] = -1.0
if 'y' in flip_axes:
flip_vector[1] = -1.0
if 'z' in flip_axes:
flip_vector[2] = -1.0

return flip_vector


def get_shift_vector(sft):
dims = sft.space_attributes[1]
shift_vector = -1.0 * (np.array(dims) / 2.0)

return shift_vector


def flip_sft(sft, flip_axes):
flip_vector = get_axis_flip_vector(flip_axes)
shift_vector = get_shift_vector(sft)

flipped_streamlines = []

streamlines = sft.streamlines

for streamline in streamlines:
mod_streamline = streamline + shift_vector
mod_streamline *= flip_vector
mod_streamline -= shift_vector
flipped_streamlines.append(mod_streamline)

new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft,
data_per_point=sft.data_per_point,
data_per_streamline=sft.data_per_streamline)
return new_sft
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
import argparse
import itertools
import json
import logging
import os
import shutil

import numpy as np

Expand Down Expand Up @@ -71,7 +68,8 @@ def main():
else:
all_matrices.append(tmp_mat)

output_measures_dict = {'SSD': [], 'correlation': [], 'w_dice_voxels' : [], 'dice_voxels' : []}
output_measures_dict = {'RMSE': [], 'correlation': [],
'w_dice_voxels': [], 'dice_voxels': []}

if args.single_compare:
if args.single_compare in args.in_matrices:
Expand All @@ -82,8 +80,8 @@ def main():
pairs = list(itertools.combinations(all_matrices, r=2))

for i in pairs:
ssd = np.sum((i[0] - i[1]) ** 2)
output_measures_dict['SSD'].append(ssd)
rmse = np.sqrt(np.mean((i[0]-i[1])**2))
output_measures_dict['RMSE'].append(rmse)
corrcoef = np.corrcoef(i[0].ravel(), i[1].ravel())
output_measures_dict['correlation'].append(corrcoef[0][1])
dice, w_dice = compute_dice_voxel(i[0], i[1])
Expand Down
208 changes: 208 additions & 0 deletions scripts/scil_fix_dsi_studio_trk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
This script is made to fix DSI-Studio TRK file (unknown space/convention) to
make it compatible with TrackVis, MI-Brain, Dipy Horizon (Stateful Tractogram).
The script either make it match with an anatomy from DSI-Studio (AC-PC aligned,
sometimes flipped) or if --in_native_fa is provided it moves it back to native
DWI space (this involved registration).
Since DSI-Studio sometimes leaves some skull around the brain, the --auto_crop
aims to stabilize registration. If this option fails, manually BET both FA.
Registration is more robust at resolution above 2mm (iso), be careful.
If you are fixing bundles, use this script once with --save_transfo and verify
results. Once satisfied, call the scripts on bundles using a bash for loop with
--load_transfo to save computation.
We recommand the --cut_invalid to remove invalid points of streamlines rather
removing entire streamlines.
This script was tested on various datasets and worked on all of them. However,
always verify the results and if a specific case does not work. Open an issue
on the Scilpy GitHub repository.
WARNING: This script is still experimental, DSI-Studio evolves quickly and
results may vary depending on the data itself as well as DSI-studio version.
"""

import argparse

from dipy.align.imaffine import (transform_centers_of_mass,
MutualInformationMetric,
AffineRegistration)
from dipy.align.transforms import RigidTransform3D
from dipy.io.stateful_tractogram import StatefulTractogram, Space
from dipy.io.utils import get_reference_info
from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.reconst.utils import _roi_in_volume, _mask_from_roi
import nibabel as nib
import numpy as np

from scilpy.io.utils import (add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.utils.streamlines import (transform_warp_streamlines,
cut_invalid_streamlines)
from scilpy.utils.transformation import flip_sft


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter)

p.add_argument('in_dsi_tractogram',
help='Path of the input tractogram file from DSI studio '
'(.trk).')
p.add_argument('in_dsi_fa',
help='Path of the input FA from DSI Studio (.nii.gz).')
p.add_argument('out_tractogram',
help='Path of the output tractogram file.')
p.add_argument('--in_native_fa',
help='Path of the input FA from Dipy/MRtrix (.nii.gz).\n'
'Move the tractogram back to a "proper" space, include'
'registration.')
p.add_argument('--auto_crop', action='store_true',
help='If both FA are not already BET, perform registration \n'
'using a centered-cube crop to ignore the skull.\n'
'A good BET for both is more robust.')
transfo = p.add_mutually_exclusive_group()
transfo.add_argument('--save_transfo', metavar='FILE',
help='Save estimated transformation to avoid '
'recomputing (.txt).')
transfo.add_argument('--load_transfo', metavar='FILE',
help='Load estimated transformation to apply to other '
'files (.txt).')
invalid = p.add_mutually_exclusive_group()
invalid.add_argument('--cut_invalid', action='store_true',
help='Cut invalid streamlines rather than removing '
'them.\nKeep the longest segment only.')
invalid.add_argument('--remove_invalid', action='store_true',
help='Remove the streamlines landing out of the '
'bounding box.')
invalid.add_argument('--keep_invalid', action='store_true',
help='Keep the streamlines landing out of the '
'bounding box.')
add_overwrite_arg(p)

return p


def get_axis_shift_vector(flip_axes):
shift_vector = np.zeros(3)
if 'x' in flip_axes:
shift_vector[0] = -1.0
if 'y' in flip_axes:
shift_vector[1] = -1.0
if 'z' in flip_axes:
shift_vector[2] = -1.0

return shift_vector


def cube_crop_data(data):
shape = np.array(data.shape[:3])
roi_center = shape // 2
roi_radii = _roi_in_volume(shape, roi_center, shape // 3)
roi_mask = _mask_from_roi(shape, roi_center, roi_radii)

return data * roi_mask


def main():
parser = _build_arg_parser()
args = parser.parse_args()
if args.load_transfo and args.in_native_fa is None:
parser.error('When loading a transformation, the final reference is '
'needed, use --in_native_fa.')
assert_inputs_exist(parser, [args.in_dsi_tractogram, args.in_dsi_fa],
optional=args.in_native_fa)
assert_outputs_exist(parser, args, args.out_tractogram)

sft = load_tractogram(args.in_dsi_tractogram, 'same',
bbox_valid_check=False)

# LPS -> RAS convention in voxel space
sft.to_vox()
flip_axis = ['x', 'y']
sft.streamlines._data -= get_axis_shift_vector(flip_axis)
sft_fix = StatefulTractogram(sft.streamlines, args.in_dsi_fa,
Space.VOX)
sft_flip = flip_sft(sft_fix, flip_axis)

if not args.in_native_fa:
if args.cut_invalid:
sft_flip, _ = cut_invalid_streamlines(sft_flip)
elif args.remove_invalid:
sft_flip.remove_invalid_streamlines()
save_tractogram(sft_flip, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)
else:
static_img = nib.load(args.in_native_fa)
static_data = static_img.get_fdata()
moving_img = nib.load(args.in_dsi_fa)
moving_data = moving_img.get_fdata()

# DSI-Studio flips the volume without changing the affine (I think)
# So this has to be reversed (not the same problem as above)
vox_order = get_reference_info(moving_img)[3]
flip_axis = []
if vox_order[0] == 'L':
moving_data = moving_data[::-1, :, :]
flip_axis.append('x')
if vox_order[1] == 'P':
moving_data = moving_data[:, ::-1, :]
flip_axis.append('y')
if vox_order[2] == 'I':
moving_data = moving_data[:, :, ::-1]
flip_axis.append('z')
sft_flip_back = flip_sft(sft_flip, flip_axis)

if args.load_transfo:
transfo = np.loadtxt(args.load_transfo)
else:
# Sometimes DSI studio has quite a lot of skull left
# Dipy Median Otsu does not work with FA/GFA
if args.auto_crop:
moving_data = cube_crop_data(moving_data)
static_data = cube_crop_data(static_data)

# Since DSI Studio register to AC/PC and does not save the
# transformation We must estimate the transformation, since it's
# rigid it is 'easy'
c_of_mass = transform_centers_of_mass(static_data, static_img.affine,
moving_data, moving_img.affine)

nbins = 32
sampling_prop = None
level_iters = [1000, 100, 10]
sigmas = [3.0, 2.0, 1.0]
factors = [3, 2, 1]
metric = MutualInformationMetric(nbins, sampling_prop)
affreg = AffineRegistration(metric=metric, level_iters=level_iters,
sigmas=sigmas, factors=factors)
transform = RigidTransform3D()
rigid = affreg.optimize(static_data, moving_data, transform, None,
static_img.affine, moving_img.affine,
starting_affine=c_of_mass.affine)
transfo = rigid.affine
if args.save_transfo:
np.savetxt(args.save_transfo, transfo)

new_sft = transform_warp_streamlines(sft_flip_back, transfo,
static_img, inverse=True,
remove_invalid=args.remove_invalid,
cut_invalid=args.cut_invalid)

if args.cut_invalid:
new_sft, _ = cut_invalid_streamlines(new_sft)
elif args.remove_invalid:
new_sft.remove_invalid_streamlines()
save_tractogram(new_sft, args.out_tractogram,
bbox_valid_check=not args.keep_invalid)


if __name__ == "__main__":
main()
40 changes: 1 addition & 39 deletions scripts/scil_flip_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
add_overwrite_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.utils.transformation import flip_sft


def _build_arg_parser():
Expand All @@ -41,45 +42,6 @@ def _build_arg_parser():
return p


def get_axis_flip_vector(flip_axes):
flip_vector = np.ones(3)
if 'x' in flip_axes:
flip_vector[0] = -1.0
if 'y' in flip_axes:
flip_vector[1] = -1.0
if 'z' in flip_axes:
flip_vector[2] = -1.0

return flip_vector


def get_shift_vector(sft):
dims = sft.space_attributes[1]
shift_vector = -1.0 * (np.array(dims) / 2.0)

return shift_vector


def flip_sft(sft, flip_axes):
flip_vector = get_axis_flip_vector(flip_axes)
shift_vector = get_shift_vector(sft)

flipped_streamlines = []

streamlines = sft.streamlines

for streamline in streamlines:
mod_streamline = streamline + shift_vector
mod_streamline *= flip_vector
mod_streamline -= shift_vector
flipped_streamlines.append(mod_streamline)

new_sft = StatefulTractogram.from_sft(flipped_streamlines, sft,
data_per_point=sft.data_per_point,
data_per_streamline=sft.data_per_streamline)
return new_sft


def main():
parser = _build_arg_parser()
args = parser.parse_args()
Expand Down
Loading

0 comments on commit d7cda04

Please sign in to comment.