From dff3f5bd3618471a0b6013f37275698d8c8fd5f0 Mon Sep 17 00:00:00 2001 From: karp2601 Date: Thu, 22 Feb 2024 17:09:26 -0500 Subject: [PATCH 01/16] Start --- scilpy/io/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index ca0a4485a..fce5447cb 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -257,15 +257,22 @@ def add_sh_basis_args(parser, mandatory=False): mandatory: bool Whether this argument is mandatory. """ - choices = ['descoteaux07', 'tournier07'] + choices = ['descoteaux07', 'tournier07', 'descoteaux07_legaycy', + 'tournier07_legacy'] def_val = 'descoteaux07' help_msg = 'Spherical harmonics basis used for the SH coefficients. ' +\ - '\nMust be either \'descoteaux07\' or \'tournier07\'' +\ + '\nMust be either \'descoteaux07\', \'tournier07\', \n' +\ + '\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\ ' [%(default)s]:\n' +\ - ' \'descoteaux07\': SH basis from the Descoteaux et al.\n' +\ - ' MRM 2007 paper\n' +\ - ' \'tournier07\' : SH basis from the Tournier et al.\n' +\ - ' NeuroImage 2007 paper.' + ' \'descoteaux07\' : SH basis from the Descoteaux et al.\n' +\ + ' MRM 2007 paper\n' +\ + ' \'tournier07\' : SH basis from the new Tournier et al.\n' +\ + ' NeuroImage 2019 paper, as in MRtrix 3.\n' +\ + ' \'descoteaux07_legacy\': SH basis from the legacy Dipy\n' +\ + ' implementation of the Descoteaux et al.\n' +\ + ' MRM 2007 paper\n' +\ + ' \'tournier07_legacy\' : SH basis from the legacy Tournier et al.\n' +\ + ' NeuroImage 2007 paper.' if mandatory: arg_name = 'sh_basis' From 6a6c46b262bc4145a8bcc18ed50c627289b060b0 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 11:30:21 -0500 Subject: [PATCH 02/16] Adding legacy options --- scilpy/io/utils.py | 54 +++++++++++++++++++++++--- scilpy/reconst/sh.py | 36 +++++++++++------ scripts/scil_aodf_metrics.py | 7 +++- scripts/scil_bundle_generate_priors.py | 13 +++++-- scripts/scil_fodf_ssst.py | 12 ++++-- scripts/scil_sh_convert.py | 32 +++++++-------- 6 files changed, 110 insertions(+), 44 deletions(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index fce5447cb..c06f00a37 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -247,8 +247,9 @@ def add_bbox_arg(parser): 'streamlines).') -def add_sh_basis_args(parser, mandatory=False): - """Add spherical harmonics (SH) bases argument. +def add_sh_basis_args(parser, mandatory=False, input_output=False): + """ + Add spherical harmonics (SH) bases argument. Parameters ---------- @@ -256,11 +257,25 @@ def add_sh_basis_args(parser, mandatory=False): Parser. mandatory: bool Whether this argument is mandatory. + input_output: bool + Whether this argument should expect both input and output bases or not. + If set, the sh_basis argument will expect first the input basis, + followed by the output basis. """ - choices = ['descoteaux07', 'tournier07', 'descoteaux07_legaycy', + if input_output: + nargs = 2 + def_val = ['descoteaux07_legacy', 'tournier07'] + input_output_msg = '\nBoth the input and output bases are ' +\ + 'required, in that order.' + else: + nargs = 1 + def_val = 'descoteaux07_legacy' + input_output_msg = '' + + choices = ['descoteaux07', 'tournier07', 'descoteaux07_legacy', 'tournier07_legacy'] - def_val = 'descoteaux07' help_msg = 'Spherical harmonics basis used for the SH coefficients. ' +\ + input_output_msg +\ '\nMust be either \'descoteaux07\', \'tournier07\', \n' +\ '\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\ ' [%(default)s]:\n' +\ @@ -279,11 +294,40 @@ def add_sh_basis_args(parser, mandatory=False): else: arg_name = '--sh_basis' - parser.add_argument(arg_name, + parser.add_argument(arg_name, nargs=nargs, choices=choices, default=def_val, help=help_msg) +def interpret_sh_basis(args): + """ + Interpret the input from args.sh_basis. If two SH bases are given, + both input/output sh_basis and is_legacy are returned. + + Parameters + ---------- + args : argparser + Argparser from a script. + + Returns + ------- + sh_basis : string + Spherical harmonic basis name. + is_legacy : bool + Whether or not the basis is in its legacy form. + """ + if len(args.sh_basis) == 2: + in_sh_basis = args.sh_basis[0].split("_")[0] + is_in_legacy = len(args.sh_basis[0].split("_")) == 2 + out_sh_basis = args.sh_basis[1].split("_")[0] + is_out_legacy = len(args.sh_basis[1].split("_")) == 2 + return in_sh_basis, is_in_legacy, out_sh_basis, is_out_legacy + else: + sh_basis = args.sh_basis[0].split("_")[0] + is_legacy = len(args.sh_basis[0].split("_")) == 2 + return sh_basis, is_legacy + + def add_nifti_screenshot_default_args( parser, slice_ids_mandatory=True, transparency_mask_mandatory=True ): diff --git a/scilpy/reconst/sh.py b/scilpy/reconst/sh.py index 3558eca4d..c3c7e0d66 100644 --- a/scilpy/reconst/sh.py +++ b/scilpy/reconst/sh.py @@ -201,8 +201,8 @@ def _peaks_from_sh_parallel(args): def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, absolute_threshold=0, min_separation_angle=25, normalize_peaks=False, npeaks=5, - sh_basis_type='descoteaux07', nbr_processes=None, - full_basis=False, is_symmetric=True): + sh_basis_type='descoteaux07', is_legacy=True, + nbr_processes=None, full_basis=False, is_symmetric=True): """Computes peaks from given spherical harmonic coefficients Parameters @@ -236,6 +236,11 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, Type of spherical harmonic basis used for `shm_coeff`. Either `descoteaux07` or `tournier07`. Default: `descoteaux07` + is_legacy: bool, optional + If true, this means that the input SH used a legacy basis definition + for backward compatibility with previous ``tournier07`` and + ``descoteaux07`` implementations. + Default: True nbr_processes: int, optional The number of subprocesses to use. Default: multiprocessing.cpu_count() @@ -252,7 +257,8 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, peak_dirs, peak_values, peak_indices """ sh_order = order_from_ncoef(shm_coeff.shape[-1], full_basis) - B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type, full_basis) + B, _ = sh_to_sf_matrix(sphere, sh_order, sh_basis_type, full_basis, + legacy=is_legacy) data_shape = shm_coeff.shape if mask is None: @@ -490,8 +496,9 @@ def _convert_sh_basis_parallel(args): def convert_sh_basis(shm_coeff, sphere, mask=None, - input_basis='descoteaux07', nbr_processes=None, - is_input_legacy=True, is_output_legacy=True): + input_basis='descoteaux07', output_basis='tournier07', + is_input_legacy=True, is_output_legacy=False, + nbr_processes=None): """Converts spherical harmonic coefficients between two bases Parameters @@ -507,9 +514,10 @@ def convert_sh_basis(shm_coeff, sphere, mask=None, Type of spherical harmonic basis used for `shm_coeff`. Either `descoteaux07` or `tournier07`. Default: `descoteaux07` - nbr_processes: int, optional - The number of subprocesses to use. - Default: multiprocessing.cpu_count() + output_basis : str, optional + Type of spherical harmonic basis wanted as output. Either + `descoteaux07` or `tournier07`. + Default: `tournier07` is_input_legacy: bool, optional If true, this means that the input SH used a legacy basis definition for backward compatibility with previous ``tournier07`` and @@ -519,16 +527,20 @@ def convert_sh_basis(shm_coeff, sphere, mask=None, If true, this means that the output SH will use a legacy basis definition for backward compatibility with previous ``tournier07`` and ``descoteaux07`` implementations. - Default: True + Default: False + nbr_processes: int, optional + The number of subprocesses to use. + Default: multiprocessing.cpu_count() Returns ------- shm_coeff_array : np.ndarray Spherical harmonic coefficients in the desired basis. """ - output_basis = 'descoteaux07' \ - if input_basis == 'tournier07' \ - else 'tournier07' + if input_basis == output_basis and is_input_legacy == is_output_legacy: + logging.info('Input and output SH basis are equal, no SH basis ' + 'convertion needed.') + return shm_coeff sh_order = order_from_ncoef(shm_coeff.shape[-1]) B_in, _ = sh_to_sf_matrix(sphere, sh_order, input_basis, diff --git a/scripts/scil_aodf_metrics.py b/scripts/scil_aodf_metrics.py index cdde59580..e7d1a7419 100755 --- a/scripts/scil_aodf_metrics.py +++ b/scripts/scil_aodf_metrics.py @@ -44,7 +44,8 @@ add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - add_overwrite_arg) + add_overwrite_arg, + interpret_sh_basis) from scilpy.io.image import get_data_as_mask @@ -146,6 +147,7 @@ def main(): sphere = get_sphere(args.sphere) + sh_basis, is_legacy = interpret_sh_basis(args) sh_order, full_basis = get_sh_order_and_fullness(sh.shape[-1]) if not full_basis and (args.asi_map or args.odd_power_map): parser.error('Invalid SH image. A full SH basis is expected.') @@ -175,7 +177,8 @@ def main(): # because v and -v are unique, we want twice # the usual default value (5) of npeaks npeaks=10, - sh_basis_type=args.sh_basis, + sh_basis_type=sh_basis, + is_legacy=is_legacy nbr_processes=args.nbr_processes, full_basis=full_basis, is_symmetric=False) diff --git a/scripts/scil_bundle_generate_priors.py b/scripts/scil_bundle_generate_priors.py index e9636188e..26ac79297 100755 --- a/scripts/scil_bundle_generate_priors.py +++ b/scripts/scil_bundle_generate_priors.py @@ -25,7 +25,8 @@ add_sh_basis_args, add_verbose_arg, assert_inputs_exist, - assert_outputs_exist) + assert_outputs_exist, + interpret_sh_basis) from scilpy.reconst.utils import find_order_from_nb_coeff from scilpy.tractanalysis.todi import TrackOrientationDensityImaging @@ -97,6 +98,7 @@ def main(): img_sh = nib.load(args.in_fodf) sh_shape = img_sh.shape sh_order = find_order_from_nb_coeff(sh_shape) + sh_basis, is_legacy = interpret_sh_basis(args) img_mask = nib.load(args.in_mask) sft = load_tractogram_with_reference(parser, args, args.in_bundle) @@ -128,13 +130,15 @@ def main(): sphere = get_sphere('repulsion724') priors_3d[sub_mask_3d] = sf_to_sh(todi_sf, sphere, sh_order=sh_order, - basis_type=args.sh_basis) + basis_type=sh_basis, + legacy=is_legacy) nib.save(nib.Nifti1Image(priors_3d, img_mask.affine), out_priors) del priors_3d input_sh_3d = img_sh.get_fdata(dtype=np.float32) input_sf_1d = sh_to_sf(input_sh_3d[sub_mask_3d], - sphere, sh_order=sh_order, basis_type=args.sh_basis) + sphere, sh_order=sh_order, + basis_type=sh_basis, legacy=is_legacy) # Creation of the enhanced-FOD (direction-wise multiplication) mult_sf_1d = input_sf_1d * todi_sf @@ -150,7 +154,8 @@ def main(): # Memory friendly saving input_sh_3d[sub_mask_3d] = sf_to_sh(mult_sf_1d, sphere, sh_order=sh_order, - basis_type=args.sh_basis) + basis_type=sh_basis, + legacy=is_legacy) nib.save(nib.Nifti1Image(input_sh_3d, img_mask.affine), out_efod) del input_sh_3d diff --git a/scripts/scil_fodf_ssst.py b/scripts/scil_fodf_ssst.py index 9c276b990..8907009cc 100755 --- a/scripts/scil_fodf_ssst.py +++ b/scripts/scil_fodf_ssst.py @@ -23,7 +23,7 @@ normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, +from scilpy.io.utils import (add_overwrite_arg, interpret_sh_basis, assert_inputs_exist, add_verbose_arg, assert_outputs_exist, add_force_b0_arg, add_sh_basis_args, add_processes_arg) @@ -87,6 +87,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sh_order = args.sh_order + sh_basis, is_legacy = interpret_sh_basis(args) # Checking data and sh_order b0_thr = check_b0_threshold( @@ -126,9 +127,12 @@ def main(): # Saving results shm_coeff = csd_fit.shm_coeff - if args.sh_basis == 'tournier07': - shm_coeff = convert_sh_basis(shm_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + shm_coeff = convert_sh_basis(shm_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32), vol.affine), args.out_fODF) diff --git a/scripts/scil_sh_convert.py b/scripts/scil_sh_convert.py index ca4bbc56a..423b0cb9d 100755 --- a/scripts/scil_sh_convert.py +++ b/scripts/scil_sh_convert.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- """ -Convert a SH file between the two commonly used bases -('descoteaux07' or 'tournier07'). The specified basis corresponds to the -input data basis. Note that by default, both legacy 'descoteaux07' and -legacy 'tournier07' bases will be assumed. For more information, see +Convert a SH file between the two of the following bases choices: +'descoteaux07', 'descoteaux07_legacy', 'tournier07' or 'tournier07_legacy'. +Using the sh_basis argument, both the input and the output SH bases must be +given, in the order. For more information about the bases, see https://dipy.org/documentation/1.4.0./theory/sh_basis/. Formerly: scil_convert_sh_basis.py @@ -21,7 +21,8 @@ from scilpy.reconst.sh import convert_sh_basis from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis) def _build_arg_parser(): @@ -33,14 +34,7 @@ def _build_arg_parser(): p.add_argument('out_sh', help='Output SH filename. (nii or nii.gz)') - p.add_argument('--in_sh_is_not_legacy', action='store_true', - help='If set, this means that the input SH are not encoded ' - 'with the legacy version of their SH basis.') - p.add_argument('--out_sh_is_not_legacy', action='store_true', - help='If set, this means that the output SH will not be ' - 'encoded with the legacy version of their SH basis.') - - add_sh_basis_args(p, mandatory=True) + add_sh_basis_args(p, mandatory=True, input_output=True) add_processes_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -60,11 +54,15 @@ def main(): img = nib.load(args.in_sh) data = img.get_fdata(dtype=np.float32) + in_sh_basis, is_in_legacy, out_sh_basis, is_out_legacy \ + = interpret_sh_basis(args) + new_data = convert_sh_basis(data, sphere, - input_basis=args.sh_basis, - nbr_processes=args.nbr_processes, - is_input_legacy=not args.in_sh_is_not_legacy, - is_output_legacy=not args.out_sh_is_not_legacy) + input_basis=in_sh_basis, + output_basis=out_sh_basis, + is_input_legacy=is_in_legacy, + is_output_legacy=is_out_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(new_data, img.affine, header=img.header), args.out_sh) From e4b23be72eceb8d187d2b4d66f90aca011b78719 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 12:19:15 -0500 Subject: [PATCH 03/16] Adding legacy support for all scripts --- scilpy/denoise/asym_filtering.py | 30 ++++++++++++---- scilpy/io/utils.py | 2 +- scilpy/reconst/fodf.py | 13 ++++--- scilpy/reconst/sh.py | 15 +++++--- scilpy/reconst/utils.py | 6 ++-- scilpy/tractanalysis/afd_along_streamlines.py | 13 ++++--- scilpy/tractanalysis/todi.py | 7 ++-- scilpy/viz/scene_utils.py | 11 +++--- scripts/scil_bundle_mean_fixel_afd.py | 10 ++++-- .../scil_bundle_mean_fixel_afd_from_hdf5.py | 16 ++++++--- scripts/scil_dwi_to_sh.py | 9 +++-- scripts/scil_fodf_max_in_ventricles.py | 8 +++-- scripts/scil_fodf_memsmt.py | 35 ++++++++++++------- scripts/scil_fodf_metrics.py | 7 ++-- scripts/scil_fodf_msmt.py | 34 +++++++++++------- scripts/scil_qball_metrics.py | 7 ++-- scripts/scil_sh_to_aodf.py | 7 ++-- scripts/scil_sh_to_sf.py | 7 ++-- scripts/scil_tracking_pft.py | 7 ++-- scripts/scil_tractogram_compute_TODI.py | 9 +++-- scripts/scil_visualize_fodf.py | 8 +++-- 21 files changed, 180 insertions(+), 81 deletions(-) diff --git a/scilpy/denoise/asym_filtering.py b/scilpy/denoise/asym_filtering.py index bb185290f..b9377a1aa 100644 --- a/scilpy/denoise/asym_filtering.py +++ b/scilpy/denoise/asym_filtering.py @@ -11,6 +11,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8, sh_basis='descoteaux07', in_full_basis=False, + is_legacy=True, sphere_str='repulsion724', sigma_spatial=1.0, sigma_angular=1.0, sigma_range=0.5, use_gpu=True): @@ -27,6 +28,8 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8, Name of SH basis used. in_full_basis: bool, optional True if input is expressed in full SH basis. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. sphere_str: str, optional Name of the DIPY sphere to use for sh to sf projection. sigma_spatial: float, optional @@ -46,6 +49,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8, if use_gpu and have_opencl: return angle_aware_bilateral_filtering_gpu(in_sh, sh_order, sh_basis, in_full_basis, + is_legacy, sphere_str, sigma_spatial, sigma_angular, sigma_range) elif use_gpu and not have_opencl: @@ -54,6 +58,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8, else: return angle_aware_bilateral_filtering_cpu(in_sh, sh_order, sh_basis, in_full_basis, + is_legacy, sphere_str, sigma_spatial, sigma_angular, sigma_range) @@ -61,6 +66,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8, def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8, sh_basis='descoteaux07', in_full_basis=False, + is_legacy=True, sphere_str='repulsion724', sigma_spatial=1.0, sigma_angular=1.0, @@ -78,6 +84,8 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8, Name of SH basis used. in_full_basis: bool, optional True if input is expressed in full SH basis. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. sphere_str: str, optional Name of the DIPY sphere to use for sh to sf projection. sigma_spatial: float, optional @@ -104,11 +112,13 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8, sh_to_sf_mat = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, full_basis=in_full_basis, + legacy=is_legacy, return_inv=False) _, sf_to_sh_mat = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, full_basis=True, + legacy=is_legacy, return_inv=True) out_n_coeffs = sf_to_sh_mat.shape[1] @@ -150,6 +160,7 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8, def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8, sh_basis='descoteaux07', in_full_basis=False, + is_legacy=True, sphere_str='repulsion724', sigma_spatial=1.0, sigma_angular=1.0, @@ -168,6 +179,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8, Name of SH basis used. in_full_basis: bool, optional True if input is expressed in full SH basis. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. sphere_str: str, optional Name of the DIPY sphere to use for sh to sf projection. sigma_spatial: float, optional @@ -194,7 +207,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8, nb_sf = len(sphere.vertices) B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, - return_inv=False, full_basis=in_full_basis) + return_inv=False, full_basis=in_full_basis, + legacy=is_legacy) mean_sf = np.zeros(in_sh.shape[:-1] + (nb_sf,)) @@ -209,7 +223,7 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8, # Convert back to SH coefficients _, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, - full_basis=True) + full_basis=True, legacy=is_legacy) out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype) # By default, return only asymmetric SH return out_sh @@ -371,7 +385,7 @@ def _correlate_spatial(image_u, h_filter, sigma_range): def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07', - in_full_basis=False, dot_sharpness=1.0, + in_full_basis=False, is_legacy=True, dot_sharpness=1.0, sphere_str='repulsion724', sigma=1.0): """ Average the SH projected on a sphere using a first-neighbor gaussian @@ -389,6 +403,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07', SH basis of the input signal. in_full_basis: bool, optional True if the input is in full SH basis. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. dot_sharpness: float, optional Exponent of the dot product. When set to 0.0, directions are not weighted by the dot product. @@ -411,13 +427,14 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07', nb_sf = len(sphere.vertices) mean_sf = np.zeros(np.append(in_sh.shape[:-1], nb_sf)) B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, - return_inv=False, full_basis=in_full_basis) + return_inv=False, full_basis=in_full_basis, + legacy=is_legacy) # We want a B matrix to project on an inverse sphere to have the sf on # the opposite hemisphere for a given vertice neg_B = sh_to_sf_matrix(Sphere(xyz=-sphere.vertices), sh_order=sh_order, basis_type=sh_basis, return_inv=False, - full_basis=in_full_basis) + full_basis=in_full_basis, legacy=is_legacy) # Apply filter to each sphere vertice for sf_i in range(nb_sf): @@ -435,7 +452,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07', # Convert back to SH coefficients _, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis, - full_basis=True) + full_basis=True, + legacy=is_legacy) out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype) return out_sh diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index c06f00a37..94fa487f9 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -314,7 +314,7 @@ def interpret_sh_basis(args): sh_basis : string Spherical harmonic basis name. is_legacy : bool - Whether or not the basis is in its legacy form. + Whether or not the SH basis is in its legacy form. """ if len(args.sh_basis) == 2: in_sh_basis = args.sh_basis[0].split("_")[0] diff --git a/scilpy/reconst/fodf.py b/scilpy/reconst/fodf.py index a64206a0c..f683d27e9 100644 --- a/scilpy/reconst/fodf.py +++ b/scilpy/reconst/fodf.py @@ -14,7 +14,8 @@ cvx, have_cvxpy, _ = optional_package("cvxpy") -def get_ventricles_max_fodf(data, fa, md, zoom, args): +def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, args, + is_legacy=True): """ Compute mean maximal fodf value in ventricules. Given heuristics thresholds on FA and MD values, finds the @@ -30,9 +31,13 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args): FA (Fractional Anisotropy) volume from DTI md: ndarray (x, y, z) MD (Mean Diffusivity) volume from DTI - vol: int > 0 - Maximum Nnumber of voxels used to compute the mean. + zoom: int > 0 + Maximum number of voxels used to compute the mean. 1000 works well at 2x2x2 = 8 mm3 + sh_basis: str + Either 'tournier07' or 'descoteaux07' + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Returns ------- @@ -42,7 +47,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args): order = find_order_from_nb_coeff(data) sphere = get_sphere('repulsion100') - b_matrix = get_b_matrix(order, sphere, args.sh_basis) + b_matrix = get_b_matrix(order, sphere, sh_basis, is_legacy=is_legacy) sum_of_max = 0 count = 0 diff --git a/scilpy/reconst/sh.py b/scilpy/reconst/sh.py index c3c7e0d66..056209d62 100644 --- a/scilpy/reconst/sh.py +++ b/scilpy/reconst/sh.py @@ -20,7 +20,7 @@ def compute_sh_coefficients(dwi, gradient_table, sh_order=4, basis_type='descoteaux07', smooth=0.006, use_attenuation=False, force_b0_threshold=False, - mask=None, sphere=None): + mask=None, sphere=None, is_legacy=True): """Fit a diffusion signal with spherical harmonics coefficients. Parameters @@ -44,6 +44,8 @@ def compute_sh_coefficients(dwi, gradient_table, sh_order=4, and reconstruction. sphere: Sphere Dipy object. If not provided, will use Sphere(xyz=bvecs). + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Returns ------- @@ -84,7 +86,8 @@ def compute_sh_coefficients(dwi, gradient_table, sh_order=4, sphere = Sphere(xyz=bvecs) # Fit SH - sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth=smooth) + sh = sf_to_sh(weights, sphere, sh_order, basis_type, smooth=smooth, + legacy=is_legacy) # Apply mask if mask is not None: @@ -599,6 +602,7 @@ def _convert_sh_to_sf_parallel(args): def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", input_basis='descoteaux07', input_full_basis=False, + is_input_legacy=True, nbr_processes=multiprocessing.cpu_count()): """Converts spherical harmonic coefficients to an SF sphere @@ -618,9 +622,11 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", Type of spherical harmonic basis used for `shm_coeff`. Either `descoteaux07` or `tournier07`. Default: `descoteaux07` - input_full_basis : bool + input_full_basis : bool, optional If True, use a full SH basis (even and odd orders) for the input SH coefficients. + is_input_legacy : bool, optional + Whether or not the input basis is in its legacy form. nbr_processes: int, optional The number of subprocesses to use. Default: multiprocessing.cpu_count() @@ -636,7 +642,8 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", sh_order = order_from_ncoef(shm_coeff.shape[-1], full_basis=input_full_basis) B_in, _ = sh_to_sf_matrix(sphere, sh_order, basis_type=input_basis, - full_basis=input_full_basis) + full_basis=input_full_basis, + legacy=is_input_legacy) B_in = B_in.astype(dtype) data_shape = shm_coeff.shape diff --git a/scilpy/reconst/utils.py b/scilpy/reconst/utils.py index 1de05fcc2..62abbb7e2 100644 --- a/scilpy/reconst/utils.py +++ b/scilpy/reconst/utils.py @@ -46,12 +46,14 @@ def _honor_authorsnames_sh_basis(sh_basis_type): return sh_basis -def get_b_matrix(order, sphere, sh_basis_type, return_all=False): +def get_b_matrix(order, sphere, sh_basis_type, return_all=False, + is_legacy=True): sh_basis = _honor_authorsnames_sh_basis(sh_basis_type) sph_harm_basis = sph_harm_lookup.get(sh_basis) if sph_harm_basis is None: raise ValueError("Invalid basis name.") - b_matrix, m, n = sph_harm_basis(order, sphere.theta, sphere.phi) + b_matrix, m, n = sph_harm_basis(order, sphere.theta, sphere.phi, + legacy=is_legacy) if return_all: return b_matrix, m, n return b_matrix diff --git a/scilpy/tractanalysis/afd_along_streamlines.py b/scilpy/tractanalysis/afd_along_streamlines.py index 06c019155..9b84633b3 100644 --- a/scilpy/tractanalysis/afd_along_streamlines.py +++ b/scilpy/tractanalysis/afd_along_streamlines.py @@ -8,7 +8,8 @@ from scilpy.tractanalysis.grid_intersections import grid_intersections -def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting): +def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting, + is_legacy=True): """ Compute the mean Apparent Fiber Density (AFD) and mean Radial fODF (radfODF) maps along a bundle. @@ -35,7 +36,8 @@ def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting): afd_sum, rd_sum, weights = \ afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, - length_weighting) + length_weighting, + is_legacy=is_legacy) non_zeros = np.nonzero(afd_sum) weights_nz = weights[non_zeros] @@ -46,7 +48,7 @@ def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting): def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, - length_weighting): + length_weighting, is_legacy=True): """ Compute the mean Apparent Fiber Density (AFD) and mean Radial fODF (radfODF) maps along a bundle. @@ -62,6 +64,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, Has to be descoteaux07 or tournier07. length_weighting : bool If set, will weigh the AFD values according to segment lengths. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Returns ------- @@ -79,7 +83,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, fodf_data = fodf.get_fdata(dtype=np.float32) order = find_order_from_nb_coeff(fodf_data) sphere = get_sphere('repulsion724') - b_matrix, _, n = get_b_matrix(order, sphere, fodf_basis, return_all=True) + b_matrix, _, n = get_b_matrix(order, sphere, fodf_basis, return_all=True, + is_legacy=is_legacy) legendre0_at_n = lpn(order, 0)[0][n] sphere_norm = np.linalg.norm(sphere.vertices) diff --git a/scilpy/tractanalysis/todi.py b/scilpy/tractanalysis/todi.py index d8f2597b3..ec643e0ef 100644 --- a/scilpy/tractanalysis/todi.py +++ b/scilpy/tractanalysis/todi.py @@ -230,7 +230,8 @@ def normalize_todi_per_voxel(self, p_norm=2): self.todi = todi_u.p_normalize_vectors(self.todi, p_norm) return self.todi - def get_sh(self, sh_basis, sh_order, smooth=0.006, full_basis=False): + def get_sh(self, sh_basis, sh_order, smooth=0.006, full_basis=False, + is_legacy=True): """Spherical Harmonics (SH) coefficients of the TODI map Compute the SH representation of the TODI map, @@ -249,6 +250,8 @@ def get_sh(self, sh_basis, sh_order, smooth=0.006, full_basis=False): smooth : float, optional Smoothing factor for the conversion, Lambda-regularization in the SH fit (default 0.006). + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Returns ------- @@ -267,7 +270,7 @@ def get_sh(self, sh_basis, sh_order, smooth=0.006, full_basis=False): """ return sf_to_sh(self.todi, self.sphere, sh_order=sh_order, basis_type=sh_basis, full_basis=full_basis, - smooth=smooth) + smooth=smooth, legacy=is_legacy) def reshape_to_3d(self, img_voxelly_masked): """Reshape a complex ravelled image to 3D. diff --git a/scilpy/viz/scene_utils.py b/scilpy/viz/scene_utils.py index dd1a9c31f..ef9448fc3 100644 --- a/scilpy/viz/scene_utils.py +++ b/scilpy/viz/scene_utils.py @@ -129,7 +129,8 @@ def set_display_extent(slicer_actor, orientation, volume_shape, slice_index): def create_odf_slicer(sh_fodf, orientation, slice_index, mask, sphere, nb_subdivide, sh_order, sh_basis, full_basis, scale, radial_scale, norm, colormap, sh_variance=None, - variance_k=1, variance_color=(255, 255, 255)): + variance_k=1, variance_color=(255, 255, 255), + is_legacy=True): """ Create a ODF slicer actor displaying a fODF slice. The input volume is a 3-dimensional grid containing the SH coefficients of the fODF for each @@ -171,6 +172,8 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, mask, sphere, Factor that multiplies sqrt(variance). variance_color : tuple, optional Color of the variance fODF data, in RGB. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Returns ------- @@ -183,15 +186,15 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, mask, sphere, # SH coefficients to SF coefficients matrix B_mat = sh_to_sf_matrix(sphere, sh_order, sh_basis, - full_basis, return_inv=False) + full_basis, return_inv=False, legacy=is_legacy) var_actor = None if sh_variance is not None: fodf = sh_to_sf(sh_fodf, sphere, sh_order, sh_basis, - full_basis=full_basis) + full_basis=full_basis, legacy=is_legacy) fodf_var = sh_to_sf(sh_variance, sphere, sh_order, sh_basis, - full_basis=full_basis) + full_basis=full_basis, legacy=is_legacy) fodf_uncertainty = fodf + variance_k * np.sqrt(np.clip(fodf_var, 0, None)) # normalise fodf and variance diff --git a/scripts/scil_bundle_mean_fixel_afd.py b/scripts/scil_bundle_mean_fixel_afd.py index 1a12ebd50..462a2860d 100755 --- a/scripts/scil_bundle_mean_fixel_afd.py +++ b/scripts/scil_bundle_mean_fixel_afd.py @@ -22,7 +22,8 @@ from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_reference_arg, add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis) from scilpy.tractanalysis.afd_along_streamlines \ import afd_map_along_streamlines @@ -68,11 +69,14 @@ def main(): sft = load_tractogram_with_reference(parser, args, args.in_bundle) fodf_img = nib.load(args.in_fodf) + sh_basis, is_legacy = interpret_sh_basis(args) + afd_mean_map, rd_mean_map = afd_map_along_streamlines( sft, fodf_img, - args.sh_basis, - args.length_weighting) + sh_basis, + args.length_weighting, + is_legacy=is_legacy) nib.Nifti1Image(afd_mean_map.astype(np.float32), fodf_img.affine).to_filename(args.afd_mean_map) diff --git a/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py b/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py index 5512a861e..ce282c79c 100755 --- a/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py +++ b/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py @@ -33,6 +33,7 @@ add_verbose_arg, assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis, validate_nbr_processes) from scilpy.tractanalysis.afd_along_streamlines \ import afd_map_along_streamlines @@ -53,6 +54,7 @@ def _afd_rd_wrapper(args): fodf_img = args[2] sh_basis = args[3] length_weighting = args[4] + is_legacy = args[5] with h5py.File(in_hdf5_filename, 'r') as in_hdf5_file: affine = in_hdf5_file.attrs['affine'] @@ -67,7 +69,8 @@ def _afd_rd_wrapper(args): origin=Origin.TRACKVIS) afd_mean_map, rd_mean_map = afd_map_along_streamlines(sft, fodf_img, sh_basis, - length_weighting) + length_weighting, + is_legacy=is_legacy) afd_mean = np.average(afd_mean_map[afd_mean_map > 0]) return key, afd_mean @@ -104,6 +107,7 @@ def main(): assert_outputs_exist(parser, args, [args.out_hdf5]) nbr_cpu = validate_nbr_processes(parser, args) + sh_basis, is_legacy = interpret_sh_basis(args) # HDF5 will not overwrite the file if os.path.isfile(args.out_hdf5): @@ -125,8 +129,9 @@ def main(): results_list = [] for key in keys: results_list.append(_afd_rd_wrapper([args.in_hdf5, key, fodf_img, - args.sh_basis, - args.length_weighting])) + sh_basis, + args.length_weighting, + is_legacy])) else: pool = multiprocessing.Pool(nbr_cpu) @@ -134,8 +139,9 @@ def main(): zip(itertools.repeat(args.in_hdf5), keys, itertools.repeat(fodf_img), - itertools.repeat(args.sh_basis), - itertools.repeat(args.length_weighting))) + itertools.repeat(sh_basis), + itertools.repeat(args.length_weighting), + itertools.repeat(is_legacy))) pool.close() pool.join() diff --git a/scripts/scil_dwi_to_sh.py b/scripts/scil_dwi_to_sh.py index c2ec20564..408ad91bc 100755 --- a/scripts/scil_dwi_to_sh.py +++ b/scripts/scil_dwi_to_sh.py @@ -18,7 +18,8 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_force_b0_arg, add_overwrite_arg, add_sh_basis_args, assert_inputs_exist, - add_verbose_arg, assert_outputs_exist) + add_verbose_arg, assert_outputs_exist, + interpret_sh_basis) from scilpy.reconst.sh import compute_sh_coefficients @@ -68,14 +69,16 @@ def main(): bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) gtab = gradient_table(args.in_bval, args.in_bvec, b0_threshold=bvals.min()) + sh_basis, is_legacy = interpret_sh_basis(args) + mask = None if args.mask: mask = get_data_as_mask(nib.load(args.mask), dtype=bool) - sh = compute_sh_coefficients(dwi, gtab, args.sh_order, args.sh_basis, + sh = compute_sh_coefficients(dwi, gtab, args.sh_order, sh_basis, args.smooth, use_attenuation=args.use_attenuation, - mask=mask) + mask=mask, is_legacy=is_legacy) nib.save(nib.Nifti1Image(sh.astype(np.float32), vol.affine), args.out_sh) diff --git a/scripts/scil_fodf_max_in_ventricles.py b/scripts/scil_fodf_max_in_ventricles.py index 40a070333..069c0de79 100755 --- a/scripts/scil_fodf_max_in_ventricles.py +++ b/scripts/scil_fodf_max_in_ventricles.py @@ -18,7 +18,8 @@ from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis) from scilpy.reconst.fodf import get_ventricles_max_fodf EPILOG = """ @@ -88,7 +89,10 @@ def main(): img_md = nib.load(args.in_md) md = img_md.get_fdata(dtype=np.float32) - value, mask = get_ventricles_max_fodf(fodf, fa, md, zoom, args) + sh_basis, is_legacy = interpret_sh_basis(args) + + value, mask = get_ventricles_max_fodf(fodf, fa, md, zoom, sh_basis, args, + is_legacy=is_legacy) if args.mask_output: img = nib.Nifti1Image(np.array(mask, 'float32'), affine) diff --git a/scripts/scil_fodf_memsmt.py b/scripts/scil_fodf_memsmt.py index 543a925f5..b372aaccf 100755 --- a/scripts/scil_fodf_memsmt.py +++ b/scripts/scil_fodf_memsmt.py @@ -51,7 +51,8 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_sh_basis_args, - add_processes_arg, add_verbose_arg) + add_processes_arg, add_verbose_arg, + interpret_sh_basis) from scilpy.reconst.fodf import fit_from_model from scilpy.reconst.sh import convert_sh_basis @@ -177,6 +178,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sh_order = args.sh_order + sh_basis, is_legacy = interpret_sh_basis(args) # Checking data and sh_order if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2: @@ -250,27 +252,36 @@ def main(): # Saving results if args.wm_out_fODF: wm_coeff = shm_coeff[..., 2:] - if args.sh_basis == 'tournier07': - wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32), affine), args.wm_out_fODF) if args.gm_out_fODF: gm_coeff = shm_coeff[..., 1] - if args.sh_basis == 'tournier07': - gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,)) - gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,)) + gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32), affine), args.gm_out_fODF) if args.csf_out_fODF: csf_coeff = shm_coeff[..., 0] - if args.sh_basis == 'tournier07': - csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,)) - csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,)) + csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32), affine), args.csf_out_fODF) diff --git a/scripts/scil_fodf_metrics.py b/scripts/scil_fodf_metrics.py index 2aa032f57..163ab5f63 100755 --- a/scripts/scil_fodf_metrics.py +++ b/scripts/scil_fodf_metrics.py @@ -44,7 +44,8 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis) from scilpy.reconst.sh import peaks_from_sh, maps_from_sh @@ -146,6 +147,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sphere = get_sphere(args.sphere) + sh_basis, is_legacy = interpret_sh_basis(args) # Computing peaks peak_dirs, peak_values, \ @@ -156,7 +158,8 @@ def main(): absolute_threshold=args.a_threshold, min_separation_angle=25, normalize_peaks=False, - sh_basis_type=args.sh_basis, + sh_basis_type=sh_basis, + is_legacy=is_legacy, nbr_processes=args.nbr_processes) # Computing maps diff --git a/scripts/scil_fodf_msmt.py b/scripts/scil_fodf_msmt.py index 8143fd0d3..c8e68d5dc 100755 --- a/scripts/scil_fodf_msmt.py +++ b/scripts/scil_fodf_msmt.py @@ -36,7 +36,7 @@ from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_force_b0_arg, add_sh_basis_args, add_processes_arg, - add_verbose_arg) + add_verbose_arg, interpret_sh_basis) from scilpy.reconst.fodf import fit_from_model from scilpy.reconst.sh import convert_sh_basis @@ -144,6 +144,7 @@ def main(): tol = args.tolerance sh_order = args.sh_order + sh_basis, is_legacy = interpret_sh_basis(args) # Checking data and sh_order b0_thr = check_b0_threshold( @@ -217,27 +218,36 @@ def main(): # Saving results if args.wm_out_fODF: wm_coeff = shm_coeff[..., 2:] - if args.sh_basis == 'tournier07': - wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + wm_coeff = convert_sh_basis(wm_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32), vol.affine), args.wm_out_fODF) if args.gm_out_fODF: gm_coeff = shm_coeff[..., 1] - if args.sh_basis == 'tournier07': - gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,)) - gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + gm_coeff = gm_coeff.reshape(gm_coeff.shape + (1,)) + gm_coeff = convert_sh_basis(gm_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32), vol.affine), args.gm_out_fODF) if args.csf_out_fODF: csf_coeff = shm_coeff[..., 0] - if args.sh_basis == 'tournier07': - csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,)) - csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask, - nbr_processes=args.nbr_processes) + csf_coeff = csf_coeff.reshape(csf_coeff.shape + (1,)) + csf_coeff = convert_sh_basis(csf_coeff, reg_sphere, mask=mask, + input_basis='descoteaux07', + output_basis=sh_basis, + is_input_legacy=True, + is_output_legacy=is_legacy, + nbr_processes=args.nbr_processes) nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32), vol.affine), args.csf_out_fODF) diff --git a/scripts/scil_qball_metrics.py b/scripts/scil_qball_metrics.py index 817ffd565..e45db19cc 100755 --- a/scripts/scil_qball_metrics.py +++ b/scripts/scil_qball_metrics.py @@ -31,7 +31,8 @@ from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_sh_basis_args, assert_inputs_exist, assert_outputs_exist, add_force_b0_arg, - validate_nbr_processes, add_verbose_arg) + validate_nbr_processes, add_verbose_arg, + interpret_sh_basis) from scilpy.io.image import get_data_as_mask from scilpy.gradients.bvec_bval_tools import (normalize_bvecs, is_normalized_bvecs, @@ -133,6 +134,7 @@ def main(): gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min()) sphere = get_sphere('symmetric724') + sh_basis, _ = interpret_sh_basis(args) mask = None if args.mask: @@ -149,6 +151,7 @@ def main(): model = CsaOdfModel(gtab, sh_order=args.sh_order, smooth=DEFAULT_SMOOTH) + # ToDo: Once Dipy adds the legacy option to peaks_from_model, put is_legacy odfpeaks = peaks_from_model(model=model, data=data, sphere=sphere, @@ -159,7 +162,7 @@ def main(): normalize_peaks=True, return_sh=True, sh_order=int(args.sh_order), - sh_basis_type=args.sh_basis, + sh_basis_type=sh_basis, npeaks=5, parallel=parallel, num_processes=nbr_processes) diff --git a/scripts/scil_sh_to_aodf.py b/scripts/scil_sh_to_aodf.py index e383b3426..967596162 100755 --- a/scripts/scil_sh_to_aodf.py +++ b/scripts/scil_sh_to_aodf.py @@ -26,7 +26,7 @@ from scilpy.reconst.utils import get_sh_order_and_fullness from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_inputs_exist, add_sh_basis_args, - assert_outputs_exist) + assert_outputs_exist, interpret_sh_basis) from scilpy.denoise.asym_filtering import (cosine_filtering, angle_aware_bilateral_filtering) @@ -113,14 +113,16 @@ def main(): data = sh_img.get_fdata(dtype=np.float32) sh_order, full_basis = get_sh_order_and_fullness(data.shape[-1]) + sh_basis, is_legacy = interpret_sh_basis(args) t0 = time.perf_counter() logging.info('Filtering SH image.') if args.method == 'bilateral': asym_sh = angle_aware_bilateral_filtering( data, sh_order=sh_order, - sh_basis=args.sh_basis, + sh_basis=sh_basis, in_full_basis=full_basis, + is_legacy=is_legacy, sphere_str=args.sphere, sigma_spatial=args.sigma_spatial, sigma_angular=args.sigma_angular, @@ -131,6 +133,7 @@ def main(): data, sh_order=sh_order, sh_basis=args.sh_basis, in_full_basis=full_basis, + is_legacy=is_legacy, sphere_str=args.sphere, dot_sharpness=args.sharpness, sigma=args.sigma_spatial) diff --git a/scripts/scil_sh_to_sf.py b/scripts/scil_sh_to_sf.py index 8a55dcd1e..9ffff83ff 100755 --- a/scripts/scil_sh_to_sf.py +++ b/scripts/scil_sh_to_sf.py @@ -26,7 +26,8 @@ from scilpy.io.utils import (add_force_b0_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, assert_inputs_exist, add_verbose_arg, - assert_outputs_exist, validate_nbr_processes) + assert_outputs_exist, interpret_sh_basis, + validate_nbr_processes) from scilpy.reconst.sh import convert_sh_to_sf from scilpy.gradients.bvec_bval_tools import (check_b0_threshold) @@ -107,6 +108,7 @@ def main(): parser.error("--in_b0 is required when using --b0_scaling.") nbr_processes = validate_nbr_processes(parser, args) + sh_basis, is_legacy = interpret_sh_basis(args) # Load SH vol_sh = nib.load(args.in_sh) @@ -123,8 +125,9 @@ def main(): sphere = Sphere(xyz=bvecs) sf = convert_sh_to_sf(data_sh, sphere, - input_basis=args.sh_basis, + input_basis=sh_basis, input_full_basis=args.full_basis, + is_input_legacy=is_legacy dtype=args.dtype, nbr_processes=nbr_processes) new_bvecs = sphere.vertices.astype(np.float32) diff --git a/scripts/scil_tracking_pft.py b/scripts/scil_tracking_pft.py index 58e2dc098..3e3a147ae 100755 --- a/scripts/scil_tracking_pft.py +++ b/scripts/scil_tracking_pft.py @@ -45,8 +45,8 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, - add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + add_verbose_arg, assert_inputs_exist, + assert_outputs_exist, interpret_sh_basis) from scilpy.tracking.utils import get_theta @@ -199,7 +199,8 @@ def main(): if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.): raise RuntimeError('Tracking sphere should be unit normed.') - sh_basis = args.sh_basis + sh_basis, _ = interpret_sh_basis(args) + # ToDo: Once Dipy adds the legacy option to dgklass, put is_legacy if args.algo == 'det': dgklass = DeterministicMaximumDirectionGetter diff --git a/scripts/scil_tractogram_compute_TODI.py b/scripts/scil_tractogram_compute_TODI.py index 7612ffc1a..519f4de2f 100755 --- a/scripts/scil_tractogram_compute_TODI.py +++ b/scripts/scil_tractogram_compute_TODI.py @@ -20,7 +20,8 @@ from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, add_sh_basis_args, add_verbose_arg, - assert_inputs_exist, assert_outputs_exist) + assert_inputs_exist, assert_outputs_exist, + interpret_sh_basis) from scilpy.tractanalysis.todi import TrackOrientationDensityImaging @@ -143,10 +144,12 @@ def main(): img.to_filename(args.out_mask) if args.out_todi_sh: + sh_basis, is_legacy = interpret_sh_basis(args) if args.normalize_per_voxel: todi_obj.normalize_todi_per_voxel() - img = todi_obj.get_sh(args.sh_basis, args.sh_order, - full_basis=args.asymmetric) + img = todi_obj.get_sh(sh_basis, args.sh_order, + full_basis=args.asymmetric, + is_legacy=is_legacy) img = todi_obj.reshape_to_3d(img) img = nib.Nifti1Image(img.astype(np.float32), affine) img.to_filename(args.out_todi_sh) diff --git a/scripts/scil_visualize_fodf.py b/scripts/scil_visualize_fodf.py index 422b0c843..06c09edcb 100755 --- a/scripts/scil_visualize_fodf.py +++ b/scripts/scil_visualize_fodf.py @@ -24,7 +24,7 @@ from scilpy.reconst.utils import get_sh_order_and_fullness from scilpy.io.utils import (add_sh_basis_args, add_overwrite_arg, assert_inputs_exist, add_verbose_arg, - assert_outputs_exist) + assert_outputs_exist, interpret_sh_basis) from scilpy.io.image import assert_same_resolution, get_data_as_mask from scilpy.viz.scene_utils import (create_odf_slicer, create_texture_slicer, create_peaks_slicer, create_scene, @@ -266,6 +266,7 @@ def main(): data = _get_data_from_inputs(args) sph = get_sphere(args.sphere) sh_order, full_basis = get_sh_order_and_fullness(data['fodf'].shape[-1]) + sh_basis, is_legacy = interpret_sh_basis(args) logging.getLogger().setLevel(logging.getLevelName(args.verbose)) actors = [] @@ -287,14 +288,15 @@ def main(): odf_actor, var_actor = create_odf_slicer(data['fodf'], args.axis_name, args.slice_index, mask, sph, args.sph_subdivide, sh_order, - args.sh_basis, full_basis, + sh_basis, full_basis, args.scale, not args.radial_scale_off, not args.norm_off, args.colormap or color_rgb, sh_variance=variance, variance_k=args.variance_k, - variance_color=var_color) + variance_color=var_color, + is_legacy=is_legacy) actors.append(odf_actor) # Instantiate a variance slicer actor if a variance image is supplied From 15810db13399dabc3c68843574426b30c19c746a Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 13:44:55 -0500 Subject: [PATCH 04/16] Adding legacy to tracking_pft --- scripts/scil_tracking_pft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/scil_tracking_pft.py b/scripts/scil_tracking_pft.py index 3e3a147ae..ad96a7d99 100755 --- a/scripts/scil_tracking_pft.py +++ b/scripts/scil_tracking_pft.py @@ -199,8 +199,7 @@ def main(): if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.): raise RuntimeError('Tracking sphere should be unit normed.') - sh_basis, _ = interpret_sh_basis(args) - # ToDo: Once Dipy adds the legacy option to dgklass, put is_legacy + sh_basis, is_legacy = interpret_sh_basis(args) if args.algo == 'det': dgklass = DeterministicMaximumDirectionGetter @@ -218,6 +217,7 @@ def main(): max_angle=theta, sphere=tracking_sphere, basis_type=sh_basis, + legacy=is_legacy, pmf_threshold=args.sf_threshold, relative_peak_threshold=args.sf_threshold_init) From aeede100e1f14b897fa895e71c56b57b73052bcf Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 13:49:44 -0500 Subject: [PATCH 05/16] Fixing typo --- scripts/scil_aodf_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/scil_aodf_metrics.py b/scripts/scil_aodf_metrics.py index e7d1a7419..3cd9c76c0 100755 --- a/scripts/scil_aodf_metrics.py +++ b/scripts/scil_aodf_metrics.py @@ -178,7 +178,7 @@ def main(): # the usual default value (5) of npeaks npeaks=10, sh_basis_type=sh_basis, - is_legacy=is_legacy + is_legacy=is_legacy, nbr_processes=args.nbr_processes, full_basis=full_basis, is_symmetric=False) From 0b79bb788005e2cf31d16804bcd6e95b02306b93 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 13:54:38 -0500 Subject: [PATCH 06/16] Fixing convert --- scripts/tests/test_sh_convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tests/test_sh_convert.py b/scripts/tests/test_sh_convert.py index a0e72d7d6..045f8dad4 100644 --- a/scripts/tests/test_sh_convert.py +++ b/scripts/tests/test_sh_convert.py @@ -22,5 +22,5 @@ def test_execution_processing(script_runner): 'fodf.nii.gz') ret = script_runner.run('scil_sh_convert.py', in_fodf, 'fodf_descoteaux07.nii.gz', 'tournier07', - '--processes', '1') + 'descoteaux07_legacy', '--processes', '1') assert ret.success From 23e283496ff830533bd27aac993ad4ba748c1ffc Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 14:45:18 -0500 Subject: [PATCH 07/16] Fixing bug --- scilpy/denoise/tests/test_asym_filtering.py | 4 ++-- scilpy/io/utils.py | 2 +- scripts/scil_dwi_to_sh.py | 1 + 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/scilpy/denoise/tests/test_asym_filtering.py b/scilpy/denoise/tests/test_asym_filtering.py index b626a0ed8..a574bffa5 100644 --- a/scilpy/denoise/tests/test_asym_filtering.py +++ b/scilpy/denoise/tests/test_asym_filtering.py @@ -22,7 +22,7 @@ def test_angle_aware_bilateral_filtering(): sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1]) out = angle_aware_bilateral_filtering_cpu(in_sh, sh_order, - sh_basis, full_basis, + sh_basis, full_basis, True, sphere_str, sigma_spatial, sigma_angular, sigma_range) @@ -40,7 +40,7 @@ def test_cosine_filtering(): sharpness = 1.0 sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1]) - out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis, + out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis, True, sharpness, sphere_str, sigma_spatial) assert np.allclose(out, fodf_3x3_order8_descoteaux07_filtered_cosine) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 94fa487f9..29448660a 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -269,7 +269,7 @@ def add_sh_basis_args(parser, mandatory=False, input_output=False): 'required, in that order.' else: nargs = 1 - def_val = 'descoteaux07_legacy' + def_val = ['descoteaux07_legacy'] input_output_msg = '' choices = ['descoteaux07', 'tournier07', 'descoteaux07_legacy', diff --git a/scripts/scil_dwi_to_sh.py b/scripts/scil_dwi_to_sh.py index 408ad91bc..1f439282f 100755 --- a/scripts/scil_dwi_to_sh.py +++ b/scripts/scil_dwi_to_sh.py @@ -70,6 +70,7 @@ def main(): gtab = gradient_table(args.in_bval, args.in_bvec, b0_threshold=bvals.min()) sh_basis, is_legacy = interpret_sh_basis(args) + print(sh_basis) mask = None if args.mask: From 652f010dd6cf7e4472584a76b549f3e3d8c41ca2 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 16:33:17 -0500 Subject: [PATCH 08/16] Test --- scilpy/tracking/propagator.py | 8 ++++++-- scilpy/tracking/tracker.py | 7 +++++-- scilpy/tracking/utils.py | 10 ++++++---- scripts/scil_sh_to_aodf.py | 2 +- scripts/scil_tracking_local.py | 13 ++++++++----- scripts/scil_tracking_local_dev.py | 7 ++++--- scripts/tests/test_tracking_local.py | 2 +- 7 files changed, 31 insertions(+), 18 deletions(-) diff --git a/scilpy/tracking/propagator.py b/scilpy/tracking/propagator.py index ff0211a65..8fc2736f9 100644 --- a/scilpy/tracking/propagator.py +++ b/scilpy/tracking/propagator.py @@ -320,7 +320,8 @@ def __init__(self, datavolume, step_size, rk_order, algo, basis, sf_threshold, sf_threshold_init, theta, dipy_sphere='symmetric724', min_separation_angle=np.pi / 16., - space=Space('vox'), origin=Origin('center')): + space=Space('vox'), origin=Origin('center'), + is_legacy=True): """ Parameters @@ -363,6 +364,8 @@ def __init__(self, datavolume, step_size, Origin of the streamlines during tracking. Default: center, like in dipy. Interpolation of the ODF is done in center origin so this choice implies the less data modification. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. """ super().__init__(datavolume, step_size, rk_order, dipy_sphere, space, origin) @@ -395,9 +398,10 @@ def __init__(self, datavolume, step_size, sh_order, full_basis =\ get_sh_order_and_fullness(self.datavolume.data.shape[-1]) self.basis = basis + self.is_legacy = is_legacy self.B = sh_to_sf_matrix(self.sphere, sh_order, self.basis, smooth=0.006, return_inv=False, - full_basis=full_basis) + full_basis=full_basis, legacy=self.is_legacy) def _get_sf(self, pos): """ diff --git a/scilpy/tracking/tracker.py b/scilpy/tracking/tracker.py index 208a33fc6..79aed3a26 100644 --- a/scilpy/tracking/tracker.py +++ b/scilpy/tracking/tracker.py @@ -503,6 +503,8 @@ class GPUTacker(): is randomly drawn from the list for each streamline. sh_basis : str, optional Spherical harmonics basis. + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. batch_size : int, optional Approximate size of GPU batches. forward_only: bool, optional @@ -514,7 +516,7 @@ class GPUTacker(): """ def __init__(self, sh, mask, seeds, step_size, max_nbr_pts, theta=20.0, sf_threshold=0.1, sh_interp='trilinear', - sh_basis='descoteaux07', batch_size=100000, + sh_basis='descoteaux07', is_legacy=True, batch_size=100000, forward_only=False, rng_seed=None, sphere=None): if not have_opencl: raise ImportError('pyopencl is not installed. In order to use' @@ -545,6 +547,7 @@ def __init__(self, sh, mask, seeds, step_size, max_nbr_pts, self.theta = np.atleast_1d(theta) self.sh_basis = sh_basis + self.is_legacy = is_legacy self.forward_only = forward_only # Instantiate random number generator @@ -602,7 +605,7 @@ def _track(self): sh_order = find_order_from_nb_coeff(self.sh) B_mat = sh_to_sf_matrix(self.sphere, sh_order, self.sh_basis, - return_inv=False) + return_inv=False, legacy=self.is_legacy) cl_manager.add_input_buffer(2, B_mat) fodf_max = self._get_max_amplitudes(B_mat) diff --git a/scilpy/tracking/utils.py b/scilpy/tracking/utils.py index 62438650a..759ef2265 100644 --- a/scilpy/tracking/utils.py +++ b/scilpy/tracking/utils.py @@ -252,7 +252,7 @@ def tracks_generator_wrapper(): def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, - voxel_size, sf_threshold, sh_to_pmf): + voxel_size, sf_threshold, sh_to_pmf, is_legacy=True): """ Return the direction getter object. Parameters @@ -276,13 +276,14 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, sh_to_pmf: bool Map sherical harmonics to spherical function (pmf) before tracking (faster, requires more memory). + is_legacy : bool, optional + Whether or not the SH basis is in its legacy form. Return ------ dg: dipy.direction.DirectionGetter The direction getter object. """ - img_data = nib.load(in_img).get_fdata(dtype=np.float32) sphere = HemiSphere.from_sphere( @@ -316,7 +317,7 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, dg_class = ProbabilisticDirectionGetter return dg_class.from_shcoeff( shcoeff=img_data, max_angle=theta, sphere=sphere, - basis_type=sh_basis, sh_to_pmf=sh_to_pmf, + basis_type=sh_basis, legacy=is_legacy, sh_to_pmf=sh_to_pmf, relative_peak_threshold=sf_threshold, **kwargs) elif algo == 'eudx': # Code for algo EUDX. We don't use peaks_from_model @@ -355,7 +356,8 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, peak_indices = np.full((img_shape_3d + (npeaks, )), -1, dtype='int') b_matrix = get_b_matrix( - find_order_from_nb_coeff(img_data), sphere, sh_basis) + find_order_from_nb_coeff(img_data), sphere, sh_basis, + legacy=is_legacy) for idx in np.argwhere(np.sum(img_data, axis=-1)): idx = tuple(idx) diff --git a/scripts/scil_sh_to_aodf.py b/scripts/scil_sh_to_aodf.py index 967596162..9560b8179 100755 --- a/scripts/scil_sh_to_aodf.py +++ b/scripts/scil_sh_to_aodf.py @@ -131,7 +131,7 @@ def main(): else: # args.method == 'cosine' asym_sh = cosine_filtering( data, sh_order=sh_order, - sh_basis=args.sh_basis, + sh_basis=sh_basis, in_full_basis=full_basis, is_legacy=is_legacy, sphere_str=args.sphere, diff --git a/scripts/scil_tracking_local.py b/scripts/scil_tracking_local.py index 7ab4b5631..5c0b4c8c5 100755 --- a/scripts/scil_tracking_local.py +++ b/scripts/scil_tracking_local.py @@ -69,7 +69,7 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - verify_compression_th) + interpret_sh_basis, verify_compression_th) from scilpy.tracking.utils import (add_mandatory_options_tracking, add_out_options, add_seeding_options, add_tracking_options, @@ -202,6 +202,8 @@ def main(): vox_step_size = args.step_size / voxel_size seed_img = nib.load(args.in_seed) + sh_basis, is_legacy = interpret_sh_basis(args) + if np.count_nonzero(seed_img.get_fdata(dtype=np.float32)) == 0: raise IOError('The image {} is empty. ' 'It can\'t be loaded as ' @@ -225,8 +227,9 @@ def main(): streamlines_generator = LocalTracking( get_direction_getter( args.in_odf, args.algo, args.sphere, - args.sub_sphere, args.theta, args.sh_basis, - voxel_size, args.sf_threshold, args.sh_to_pmf), + args.sub_sphere, args.theta, sh_basis, + voxel_size, args.sf_threshold, args.sh_to_pmf, + is_legacy=is_legacy), BinaryStoppingCriterion(mask_data), seeds, np.eye(4), step_size=vox_step_size, max_cross=1, @@ -252,7 +255,8 @@ def main(): theta=get_theta(args.theta, args.algo), sf_threshold=args.sf_threshold, sh_interp=sh_interp, - sh_basis=args.sh_basis, + sh_basis=sh_basis, + is_legacy=is_legacy, batch_size=batch_size, forward_only=forward_only, rng_seed=args.seed, @@ -263,7 +267,6 @@ def main(): odf_sh_img, total_nb_seeds, args.out_tractogram, args.min_length, args.max_length, args.compress, args.save_seeds, args.verbose) - # Final logging logging.info('Saved tractogram to {0}.'.format(args.out_tractogram)) diff --git a/scripts/scil_tracking_local_dev.py b/scripts/scil_tracking_local_dev.py index 37bcdcf86..0bb19629a 100755 --- a/scripts/scil_tracking_local_dev.py +++ b/scripts/scil_tracking_local_dev.py @@ -59,7 +59,7 @@ from scilpy.io.utils import (add_processes_arg, add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - verify_compression_th) + interpret_sh_basis, verify_compression_th) from scilpy.image.volume_space_management import DataVolume from scilpy.tracking.propagator import ODFPropagator from scilpy.tracking.seed import SeedGenerator @@ -237,10 +237,11 @@ def main(): # Using space and origin in the propagator: vox and center, like # in dipy. + sh_basis, is_legacy = interpret_sh_basis(args) propagator = ODFPropagator( - dataset, vox_step_size, args.rk_order, args.algo, args.sh_basis, + dataset, vox_step_size, args.rk_order, args.algo, sh_basis, args.sf_threshold, args.sf_threshold_init, theta, args.sphere, - space=our_space, origin=our_origin) + space=our_space, origin=our_origin, is_legacy=is_legacy) logging.info("Instantiating tracker.") tracker = Tracker(propagator, mask, seed_generator, nbr_seeds, min_nbr_pts, diff --git a/scripts/tests/test_tracking_local.py b/scripts/tests/test_tracking_local.py index 4e0525fbe..88d150fde 100644 --- a/scripts/tests/test_tracking_local.py +++ b/scripts/tests/test_tracking_local.py @@ -167,5 +167,5 @@ def test_execution_tracking_fodf_prob_pmf_mapping(script_runner): in_mask, in_mask, 'local_prob3.trk', '--nt', '100', '--compress', '0.1', '--sh_basis', 'descoteaux07', '--min_length', '20', '--max_length', '200', - '--sh_to_pmf') + '--sh_to_pmf', '-v') assert ret.success From b9036c5fd556ece083a1aa8633d09557dee07d30 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Fri, 23 Feb 2024 16:47:38 -0500 Subject: [PATCH 09/16] Fixes --- scripts/scil_sh_to_sf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/scil_sh_to_sf.py b/scripts/scil_sh_to_sf.py index 9ffff83ff..15eb5841d 100755 --- a/scripts/scil_sh_to_sf.py +++ b/scripts/scil_sh_to_sf.py @@ -127,7 +127,7 @@ def main(): sf = convert_sh_to_sf(data_sh, sphere, input_basis=sh_basis, input_full_basis=args.full_basis, - is_input_legacy=is_legacy + is_input_legacy=is_legacy, dtype=args.dtype, nbr_processes=nbr_processes) new_bvecs = sphere.vertices.astype(np.float32) From 962899945109231925d10a0341f455d32d2423c6 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Mon, 26 Feb 2024 15:38:34 -0500 Subject: [PATCH 10/16] Adressing Charles comments --- scilpy/io/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 83b1625e7..3827cccbd 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -292,15 +292,15 @@ def add_sh_basis_args(parser, mandatory=False, input_output=False): help=help_msg) -def interpret_sh_basis(args): +def parse_sh_basis_arg(args): """ - Interpret the input from args.sh_basis. If two SH bases are given, + Parser the input from args.sh_basis. If two SH bases are given, both input/output sh_basis and is_legacy are returned. Parameters ---------- - args : argparser - Argparser from a script. + args : ArgumentParser.parse_args + ArgumentParser.parse_args from a script. Returns ------- From 7cd21ee80b0fc96c1e669f04e44e8954116d6c50 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Tue, 27 Feb 2024 08:00:53 -0500 Subject: [PATCH 11/16] Fix naming --- scripts/scil_aodf_metrics.py | 4 ++-- scripts/scil_bundle_generate_priors.py | 4 ++-- scripts/scil_bundle_mean_fixel_afd.py | 4 ++-- scripts/scil_bundle_mean_fixel_afd_from_hdf5.py | 4 ++-- scripts/scil_dwi_to_sh.py | 4 ++-- scripts/scil_fodf_max_in_ventricles.py | 4 ++-- scripts/scil_fodf_memsmt.py | 4 ++-- scripts/scil_fodf_metrics.py | 4 ++-- scripts/scil_fodf_msmt.py | 4 ++-- scripts/scil_fodf_ssst.py | 4 ++-- scripts/scil_qball_metrics.py | 4 ++-- scripts/scil_sh_convert.py | 4 ++-- scripts/scil_sh_to_aodf.py | 4 ++-- scripts/scil_sh_to_sf.py | 4 ++-- scripts/scil_tracking_local.py | 4 ++-- scripts/scil_tracking_local_dev.py | 4 ++-- scripts/scil_tracking_pft.py | 4 ++-- scripts/scil_tractogram_compute_TODI.py | 4 ++-- scripts/scil_visualize_fodf.py | 4 ++-- 19 files changed, 38 insertions(+), 38 deletions(-) diff --git a/scripts/scil_aodf_metrics.py b/scripts/scil_aodf_metrics.py index a1faaab58..d6425cfcd 100755 --- a/scripts/scil_aodf_metrics.py +++ b/scripts/scil_aodf_metrics.py @@ -45,7 +45,7 @@ assert_inputs_exist, assert_outputs_exist, add_overwrite_arg, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.io.image import get_data_as_mask @@ -149,7 +149,7 @@ def main(): sphere = get_sphere(args.sphere) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) sh_order, full_basis = get_sh_order_and_fullness(sh.shape[-1]) if not full_basis and (args.asi_map or args.odd_power_map): parser.error('Invalid SH image. A full SH basis is expected.') diff --git a/scripts/scil_bundle_generate_priors.py b/scripts/scil_bundle_generate_priors.py index 26ac79297..18b4c64fb 100755 --- a/scripts/scil_bundle_generate_priors.py +++ b/scripts/scil_bundle_generate_priors.py @@ -26,7 +26,7 @@ add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.reconst.utils import find_order_from_nb_coeff from scilpy.tractanalysis.todi import TrackOrientationDensityImaging @@ -98,7 +98,7 @@ def main(): img_sh = nib.load(args.in_fodf) sh_shape = img_sh.shape sh_order = find_order_from_nb_coeff(sh_shape) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) img_mask = nib.load(args.in_mask) sft = load_tractogram_with_reference(parser, args, args.in_bundle) diff --git a/scripts/scil_bundle_mean_fixel_afd.py b/scripts/scil_bundle_mean_fixel_afd.py index 462a2860d..8cdd859a9 100755 --- a/scripts/scil_bundle_mean_fixel_afd.py +++ b/scripts/scil_bundle_mean_fixel_afd.py @@ -23,7 +23,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_reference_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.tractanalysis.afd_along_streamlines \ import afd_map_along_streamlines @@ -69,7 +69,7 @@ def main(): sft = load_tractogram_with_reference(parser, args, args.in_bundle) fodf_img = nib.load(args.in_fodf) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) afd_mean_map, rd_mean_map = afd_map_along_streamlines( sft, diff --git a/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py b/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py index ce282c79c..a10ef4c2a 100755 --- a/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py +++ b/scripts/scil_bundle_mean_fixel_afd_from_hdf5.py @@ -33,7 +33,7 @@ add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis, + parse_sh_basis_arg, validate_nbr_processes) from scilpy.tractanalysis.afd_along_streamlines \ import afd_map_along_streamlines @@ -107,7 +107,7 @@ def main(): assert_outputs_exist(parser, args, [args.out_hdf5]) nbr_cpu = validate_nbr_processes(parser, args) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # HDF5 will not overwrite the file if os.path.isfile(args.out_hdf5): diff --git a/scripts/scil_dwi_to_sh.py b/scripts/scil_dwi_to_sh.py index 1f439282f..e19ab90cc 100755 --- a/scripts/scil_dwi_to_sh.py +++ b/scripts/scil_dwi_to_sh.py @@ -19,7 +19,7 @@ from scilpy.io.utils import (add_force_b0_arg, add_overwrite_arg, add_sh_basis_args, assert_inputs_exist, add_verbose_arg, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.reconst.sh import compute_sh_coefficients @@ -69,7 +69,7 @@ def main(): bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) gtab = gradient_table(args.in_bval, args.in_bvec, b0_threshold=bvals.min()) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) print(sh_basis) mask = None diff --git a/scripts/scil_fodf_max_in_ventricles.py b/scripts/scil_fodf_max_in_ventricles.py index 069c0de79..b4b4aef58 100755 --- a/scripts/scil_fodf_max_in_ventricles.py +++ b/scripts/scil_fodf_max_in_ventricles.py @@ -19,7 +19,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.reconst.fodf import get_ventricles_max_fodf EPILOG = """ @@ -89,7 +89,7 @@ def main(): img_md = nib.load(args.in_md) md = img_md.get_fdata(dtype=np.float32) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) value, mask = get_ventricles_max_fodf(fodf, fa, md, zoom, sh_basis, args, is_legacy=is_legacy) diff --git a/scripts/scil_fodf_memsmt.py b/scripts/scil_fodf_memsmt.py index b372aaccf..99695eac6 100755 --- a/scripts/scil_fodf_memsmt.py +++ b/scripts/scil_fodf_memsmt.py @@ -52,7 +52,7 @@ from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_sh_basis_args, add_processes_arg, add_verbose_arg, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.reconst.fodf import fit_from_model from scilpy.reconst.sh import convert_sh_basis @@ -178,7 +178,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sh_order = args.sh_order - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking data and sh_order if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2: diff --git a/scripts/scil_fodf_metrics.py b/scripts/scil_fodf_metrics.py index 163ab5f63..195922b85 100755 --- a/scripts/scil_fodf_metrics.py +++ b/scripts/scil_fodf_metrics.py @@ -45,7 +45,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.reconst.sh import peaks_from_sh, maps_from_sh @@ -147,7 +147,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sphere = get_sphere(args.sphere) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # Computing peaks peak_dirs, peak_values, \ diff --git a/scripts/scil_fodf_msmt.py b/scripts/scil_fodf_msmt.py index c8e68d5dc..4e40be570 100755 --- a/scripts/scil_fodf_msmt.py +++ b/scripts/scil_fodf_msmt.py @@ -36,7 +36,7 @@ from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_force_b0_arg, add_sh_basis_args, add_processes_arg, - add_verbose_arg, interpret_sh_basis) + add_verbose_arg, parse_sh_basis_arg) from scilpy.reconst.fodf import fit_from_model from scilpy.reconst.sh import convert_sh_basis @@ -144,7 +144,7 @@ def main(): tol = args.tolerance sh_order = args.sh_order - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking data and sh_order b0_thr = check_b0_threshold( diff --git a/scripts/scil_fodf_ssst.py b/scripts/scil_fodf_ssst.py index 8907009cc..f58953290 100755 --- a/scripts/scil_fodf_ssst.py +++ b/scripts/scil_fodf_ssst.py @@ -23,7 +23,7 @@ normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, interpret_sh_basis, +from scilpy.io.utils import (add_overwrite_arg, parse_sh_basis_arg, assert_inputs_exist, add_verbose_arg, assert_outputs_exist, add_force_b0_arg, add_sh_basis_args, add_processes_arg) @@ -87,7 +87,7 @@ def main(): raise ValueError("Mask is not the same shape as data.") sh_order = args.sh_order - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking data and sh_order b0_thr = check_b0_threshold( diff --git a/scripts/scil_qball_metrics.py b/scripts/scil_qball_metrics.py index e45db19cc..cd19b29e0 100755 --- a/scripts/scil_qball_metrics.py +++ b/scripts/scil_qball_metrics.py @@ -32,7 +32,7 @@ add_sh_basis_args, assert_inputs_exist, assert_outputs_exist, add_force_b0_arg, validate_nbr_processes, add_verbose_arg, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.io.image import get_data_as_mask from scilpy.gradients.bvec_bval_tools import (normalize_bvecs, is_normalized_bvecs, @@ -134,7 +134,7 @@ def main(): gtab = gradient_table(bvals, bvecs, b0_threshold=bvals.min()) sphere = get_sphere('symmetric724') - sh_basis, _ = interpret_sh_basis(args) + sh_basis, _ = parse_sh_basis_arg(args) mask = None if args.mask: diff --git a/scripts/scil_sh_convert.py b/scripts/scil_sh_convert.py index 423b0cb9d..c2dc81c0e 100755 --- a/scripts/scil_sh_convert.py +++ b/scripts/scil_sh_convert.py @@ -22,7 +22,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) def _build_arg_parser(): @@ -55,7 +55,7 @@ def main(): data = img.get_fdata(dtype=np.float32) in_sh_basis, is_in_legacy, out_sh_basis, is_out_legacy \ - = interpret_sh_basis(args) + = parse_sh_basis_arg(args) new_data = convert_sh_basis(data, sphere, input_basis=in_sh_basis, diff --git a/scripts/scil_sh_to_aodf.py b/scripts/scil_sh_to_aodf.py index 9560b8179..a3a31d5ca 100755 --- a/scripts/scil_sh_to_aodf.py +++ b/scripts/scil_sh_to_aodf.py @@ -26,7 +26,7 @@ from scilpy.reconst.utils import get_sh_order_and_fullness from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_inputs_exist, add_sh_basis_args, - assert_outputs_exist, interpret_sh_basis) + assert_outputs_exist, parse_sh_basis_arg) from scilpy.denoise.asym_filtering import (cosine_filtering, angle_aware_bilateral_filtering) @@ -113,7 +113,7 @@ def main(): data = sh_img.get_fdata(dtype=np.float32) sh_order, full_basis = get_sh_order_and_fullness(data.shape[-1]) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) t0 = time.perf_counter() logging.info('Filtering SH image.') diff --git a/scripts/scil_sh_to_sf.py b/scripts/scil_sh_to_sf.py index 15eb5841d..cc47be01e 100755 --- a/scripts/scil_sh_to_sf.py +++ b/scripts/scil_sh_to_sf.py @@ -26,7 +26,7 @@ from scilpy.io.utils import (add_force_b0_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, assert_inputs_exist, add_verbose_arg, - assert_outputs_exist, interpret_sh_basis, + assert_outputs_exist, parse_sh_basis_arg, validate_nbr_processes) from scilpy.reconst.sh import convert_sh_to_sf from scilpy.gradients.bvec_bval_tools import (check_b0_threshold) @@ -108,7 +108,7 @@ def main(): parser.error("--in_b0 is required when using --b0_scaling.") nbr_processes = validate_nbr_processes(parser, args) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) # Load SH vol_sh = nib.load(args.in_sh) diff --git a/scripts/scil_tracking_local.py b/scripts/scil_tracking_local.py index 5c0b4c8c5..b40a70eba 100755 --- a/scripts/scil_tracking_local.py +++ b/scripts/scil_tracking_local.py @@ -69,7 +69,7 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis, verify_compression_th) + parse_sh_basis_arg, verify_compression_th) from scilpy.tracking.utils import (add_mandatory_options_tracking, add_out_options, add_seeding_options, add_tracking_options, @@ -202,7 +202,7 @@ def main(): vox_step_size = args.step_size / voxel_size seed_img = nib.load(args.in_seed) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) if np.count_nonzero(seed_img.get_fdata(dtype=np.float32)) == 0: raise IOError('The image {} is empty. ' diff --git a/scripts/scil_tracking_local_dev.py b/scripts/scil_tracking_local_dev.py index 0bb19629a..892d50143 100755 --- a/scripts/scil_tracking_local_dev.py +++ b/scripts/scil_tracking_local_dev.py @@ -59,7 +59,7 @@ from scilpy.io.utils import (add_processes_arg, add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis, verify_compression_th) + parse_sh_basis_arg, verify_compression_th) from scilpy.image.volume_space_management import DataVolume from scilpy.tracking.propagator import ODFPropagator from scilpy.tracking.seed import SeedGenerator @@ -237,7 +237,7 @@ def main(): # Using space and origin in the propagator: vox and center, like # in dipy. - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) propagator = ODFPropagator( dataset, vox_step_size, args.rk_order, args.algo, sh_basis, args.sf_threshold, args.sf_threshold_init, theta, args.sphere, diff --git a/scripts/scil_tracking_pft.py b/scripts/scil_tracking_pft.py index ad96a7d99..803fed776 100755 --- a/scripts/scil_tracking_pft.py +++ b/scripts/scil_tracking_pft.py @@ -46,7 +46,7 @@ from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_verbose_arg, assert_inputs_exist, - assert_outputs_exist, interpret_sh_basis) + assert_outputs_exist, parse_sh_basis_arg) from scilpy.tracking.utils import get_theta @@ -199,7 +199,7 @@ def main(): if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.): raise RuntimeError('Tracking sphere should be unit normed.') - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) if args.algo == 'det': dgklass = DeterministicMaximumDirectionGetter diff --git a/scripts/scil_tractogram_compute_TODI.py b/scripts/scil_tractogram_compute_TODI.py index 519f4de2f..9286b8c86 100755 --- a/scripts/scil_tractogram_compute_TODI.py +++ b/scripts/scil_tractogram_compute_TODI.py @@ -21,7 +21,7 @@ from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, add_sh_basis_args, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, - interpret_sh_basis) + parse_sh_basis_arg) from scilpy.tractanalysis.todi import TrackOrientationDensityImaging @@ -144,7 +144,7 @@ def main(): img.to_filename(args.out_mask) if args.out_todi_sh: - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) if args.normalize_per_voxel: todi_obj.normalize_todi_per_voxel() img = todi_obj.get_sh(sh_basis, args.sh_order, diff --git a/scripts/scil_visualize_fodf.py b/scripts/scil_visualize_fodf.py index 06c09edcb..bcc79e9ce 100755 --- a/scripts/scil_visualize_fodf.py +++ b/scripts/scil_visualize_fodf.py @@ -24,7 +24,7 @@ from scilpy.reconst.utils import get_sh_order_and_fullness from scilpy.io.utils import (add_sh_basis_args, add_overwrite_arg, assert_inputs_exist, add_verbose_arg, - assert_outputs_exist, interpret_sh_basis) + assert_outputs_exist, parse_sh_basis_arg) from scilpy.io.image import assert_same_resolution, get_data_as_mask from scilpy.viz.scene_utils import (create_odf_slicer, create_texture_slicer, create_peaks_slicer, create_scene, @@ -266,7 +266,7 @@ def main(): data = _get_data_from_inputs(args) sph = get_sphere(args.sphere) sh_order, full_basis = get_sh_order_and_fullness(data['fodf'].shape[-1]) - sh_basis, is_legacy = interpret_sh_basis(args) + sh_basis, is_legacy = parse_sh_basis_arg(args) logging.getLogger().setLevel(logging.getLevelName(args.verbose)) actors = [] From 17d477088124ba161de824c666aa097489cbfeab Mon Sep 17 00:00:00 2001 From: karp2601 Date: Wed, 28 Feb 2024 15:07:09 -0500 Subject: [PATCH 12/16] Charles and Arnaud's comments --- scilpy/io/utils.py | 21 ++++++++++++--------- scripts/scil_dwi_to_sh.py | 1 - scripts/scil_sh_convert.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 3827cccbd..c08b4ea76 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -242,7 +242,8 @@ def add_bbox_arg(parser): def add_sh_basis_args(parser, mandatory=False, input_output=False): """ - Add spherical harmonics (SH) bases argument. + Add spherical harmonics (SH) bases argument. For more information about + the bases, see https://docs.dipy.org/stable/theory/sh_basis.html. Parameters ---------- @@ -308,16 +309,18 @@ def parse_sh_basis_arg(args): Spherical harmonic basis name. is_legacy : bool Whether or not the SH basis is in its legacy form. - """ + """ + sh_basis_name = args.sh_basis[0] + sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \ + else 'tournier07' + is_legacy = 'legacy' in sh_basis_name if len(args.sh_basis) == 2: - in_sh_basis = args.sh_basis[0].split("_")[0] - is_in_legacy = len(args.sh_basis[0].split("_")) == 2 - out_sh_basis = args.sh_basis[1].split("_")[0] - is_out_legacy = len(args.sh_basis[1].split("_")) == 2 - return in_sh_basis, is_in_legacy, out_sh_basis, is_out_legacy + sh_basis_name = args.sh_basis[1] + out_sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \ + else 'tournier07' + is_out_legacy = 'legacy' in sh_basis_name + return sh_basis, is_legacy, out_sh_basis, is_out_legacy else: - sh_basis = args.sh_basis[0].split("_")[0] - is_legacy = len(args.sh_basis[0].split("_")) == 2 return sh_basis, is_legacy diff --git a/scripts/scil_dwi_to_sh.py b/scripts/scil_dwi_to_sh.py index e19ab90cc..c1947d3f0 100755 --- a/scripts/scil_dwi_to_sh.py +++ b/scripts/scil_dwi_to_sh.py @@ -70,7 +70,6 @@ def main(): gtab = gradient_table(args.in_bval, args.in_bvec, b0_threshold=bvals.min()) sh_basis, is_legacy = parse_sh_basis_arg(args) - print(sh_basis) mask = None if args.mask: diff --git a/scripts/scil_sh_convert.py b/scripts/scil_sh_convert.py index c2dc81c0e..5cd192521 100755 --- a/scripts/scil_sh_convert.py +++ b/scripts/scil_sh_convert.py @@ -6,7 +6,7 @@ 'descoteaux07', 'descoteaux07_legacy', 'tournier07' or 'tournier07_legacy'. Using the sh_basis argument, both the input and the output SH bases must be given, in the order. For more information about the bases, see -https://dipy.org/documentation/1.4.0./theory/sh_basis/. +https://docs.dipy.org/stable/theory/sh_basis.html. Formerly: scil_convert_sh_basis.py """ From 5d05e2f4656d05648e2d3fae9d3c9504ed373c66 Mon Sep 17 00:00:00 2001 From: karp2601 Date: Wed, 28 Feb 2024 16:18:31 -0500 Subject: [PATCH 13/16] Fix pep8 --- scilpy/io/utils.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index c08b4ea76..6bf566ee1 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -273,15 +273,20 @@ def add_sh_basis_args(parser, mandatory=False, input_output=False): '\nMust be either \'descoteaux07\', \'tournier07\', \n' +\ '\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\ ' [%(default)s]:\n' +\ - ' \'descoteaux07\' : SH basis from the Descoteaux et al.\n' +\ - ' MRM 2007 paper\n' +\ - ' \'tournier07\' : SH basis from the new Tournier et al.\n' +\ - ' NeuroImage 2019 paper, as in MRtrix 3.\n' +\ - ' \'descoteaux07_legacy\': SH basis from the legacy Dipy\n' +\ - ' implementation of the Descoteaux et al.\n' +\ - ' MRM 2007 paper\n' +\ - ' \'tournier07_legacy\' : SH basis from the legacy Tournier et al.\n' +\ - ' NeuroImage 2007 paper.' + ' \'descoteaux07\' : SH basis from the Descoteaux ' +\ + 'et al.\n' +\ + ' MRM 2007 paper\n' +\ + ' \'tournier07\' : SH basis from the new ' +\ + 'Tournier et al.\n' +\ + ' NeuroImage 2019 paper, as in ' +\ + 'MRtrix 3.\n' +\ + ' \'descoteaux07_legacy\': SH basis from the legacy Dipy ' +\ + 'implementation\n' +\ + ' of the ' +\ + 'Descoteaux et al. MRM 2007 paper\n' +\ + ' \'tournier07_legacy\' : SH basis from the legacy ' +\ + 'Tournier et al.\n' +\ + ' NeuroImage 2007 paper.' if mandatory: arg_name = 'sh_basis' @@ -309,15 +314,15 @@ def parse_sh_basis_arg(args): Spherical harmonic basis name. is_legacy : bool Whether or not the SH basis is in its legacy form. - """ + """ sh_basis_name = args.sh_basis[0] sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \ - else 'tournier07' + else 'tournier07' is_legacy = 'legacy' in sh_basis_name if len(args.sh_basis) == 2: sh_basis_name = args.sh_basis[1] out_sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \ - else 'tournier07' + else 'tournier07' is_out_legacy = 'legacy' in sh_basis_name return sh_basis, is_legacy, out_sh_basis, is_out_legacy else: From d7e24563a0bab5d472f215a75af3f1160346a3af Mon Sep 17 00:00:00 2001 From: karp2601 Date: Wed, 28 Feb 2024 16:57:36 -0500 Subject: [PATCH 14/16] Removing get_b_matrix function --- scilpy/reconst/fodf.py | 7 +++-- scilpy/reconst/tests/test_utils.py | 5 ---- scilpy/reconst/utils.py | 28 ------------------- scilpy/tracking/utils.py | 12 ++++---- scilpy/tractanalysis/afd_along_streamlines.py | 9 +++--- 5 files changed, 15 insertions(+), 46 deletions(-) diff --git a/scilpy/reconst/fodf.py b/scilpy/reconst/fodf.py index f683d27e9..0afe48915 100644 --- a/scilpy/reconst/fodf.py +++ b/scilpy/reconst/fodf.py @@ -7,8 +7,9 @@ from dipy.data import get_sphere from dipy.reconst.mcsd import MSDeconvFit from dipy.reconst.multi_voxel import MultiVoxelFit +from dipy.reconst.shm import sh_to_sf_matrix -from scilpy.reconst.utils import find_order_from_nb_coeff, get_b_matrix +from scilpy.reconst.utils import find_order_from_nb_coeff from dipy.utils.optpkg import optional_package cvx, have_cvxpy, _ = optional_package("cvxpy") @@ -47,7 +48,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, args, order = find_order_from_nb_coeff(data) sphere = get_sphere('repulsion100') - b_matrix = get_b_matrix(order, sphere, sh_basis, is_legacy=is_legacy) + b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy) sum_of_max = 0 count = 0 @@ -91,7 +92,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, args, continue if fa[i, j, k] < args.fa_threshold \ and md[i, j, k] > args.md_threshold: - sf = np.dot(data[i, j, k], b_matrix.T) + sf = np.dot(data[i, j, k], b_matrix) sum_of_max += sf.max() count += 1 mask[i, j, k] = 1 diff --git a/scilpy/reconst/tests/test_utils.py b/scilpy/reconst/tests/test_utils.py index 53bd10779..a3e307761 100644 --- a/scilpy/reconst/tests/test_utils.py +++ b/scilpy/reconst/tests/test_utils.py @@ -11,11 +11,6 @@ def test_sh_basis(): pass -def test_get_b_matrix(): - # toDO - pass - - def test_get_maximas(): # toDO pass diff --git a/scilpy/reconst/utils.py b/scilpy/reconst/utils.py index 62abbb7e2..9c6dba5cc 100644 --- a/scilpy/reconst/utils.py +++ b/scilpy/reconst/utils.py @@ -31,34 +31,6 @@ def get_sh_order_and_fullness(ncoeffs): raise ValueError('Invalid number of coefficients for SH basis.') -def _honor_authorsnames_sh_basis(sh_basis_type): - sh_basis = sh_basis_type - if sh_basis_type == 'fibernav': - sh_basis = 'descoteaux07' - warnings.warn("'fibernav' sph basis name is deprecated and will be " - "discontinued in favor of 'descoteaux07'.", - DeprecationWarning) - elif sh_basis_type == 'mrtrix': - sh_basis = 'tournier07' - warnings.warn("'mrtrix' sph basis name is deprecated and will be " - "discontinued in favor of 'tournier07'.", - DeprecationWarning) - return sh_basis - - -def get_b_matrix(order, sphere, sh_basis_type, return_all=False, - is_legacy=True): - sh_basis = _honor_authorsnames_sh_basis(sh_basis_type) - sph_harm_basis = sph_harm_lookup.get(sh_basis) - if sph_harm_basis is None: - raise ValueError("Invalid basis name.") - b_matrix, m, n = sph_harm_basis(order, sphere.theta, sphere.phi, - legacy=is_legacy) - if return_all: - return b_matrix, m, n - return b_matrix - - def get_maximas(data, sphere, b_matrix, threshold, absolute_threshold, min_separation_angle=25): spherical_func = np.dot(data, b_matrix.T) diff --git a/scilpy/tracking/utils.py b/scilpy/tracking/utils.py index 759ef2265..4f79a7ebd 100644 --- a/scilpy/tracking/utils.py +++ b/scilpy/tracking/utils.py @@ -14,9 +14,9 @@ from dipy.direction.peaks import PeaksAndMetrics from dipy.io.utils import (get_reference_info, create_tractogram_header) +from dipy.reconst.shm import sh_to_sf_matrix from dipy.tracking.streamlinespeed import length, compress_streamlines -from scilpy.reconst.utils import (find_order_from_nb_coeff, - get_b_matrix, get_maximas) +from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas from nibabel.streamlines import TrkFile from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem @@ -355,14 +355,14 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, peak_values = np.zeros((img_shape_3d + (npeaks, ))) peak_indices = np.full((img_shape_3d + (npeaks, )), -1, dtype='int') - b_matrix = get_b_matrix( - find_order_from_nb_coeff(img_data), sphere, sh_basis, - legacy=is_legacy) + b_matrix, _ = sh_to_sf_matrix(sphere, + find_order_from_nb_coeff(img_data), + sh_basis, legacy=is_legacy) for idx in np.argwhere(np.sum(img_data, axis=-1)): idx = tuple(idx) directions, values, indices = get_maximas(img_data[idx], - sphere, b_matrix, + sphere, b_matrix.T, sf_threshold, 0) if values.shape[0] != 0: n = min(npeaks, values.shape[0]) diff --git a/scilpy/tractanalysis/afd_along_streamlines.py b/scilpy/tractanalysis/afd_along_streamlines.py index 9b84633b3..fe401f597 100644 --- a/scilpy/tractanalysis/afd_along_streamlines.py +++ b/scilpy/tractanalysis/afd_along_streamlines.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- from dipy.data import get_sphere +from dipy.reconst.shm import sh_to_sf_matrix, sph_harm_ind_list import numpy as np from scipy.special import lpn -from scilpy.reconst.utils import find_order_from_nb_coeff, get_b_matrix +from scilpy.reconst.utils import find_order_from_nb_coeff from scilpy.tractanalysis.grid_intersections import grid_intersections @@ -83,8 +84,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, fodf_data = fodf.get_fdata(dtype=np.float32) order = find_order_from_nb_coeff(fodf_data) sphere = get_sphere('repulsion724') - b_matrix, _, n = get_b_matrix(order, sphere, fodf_basis, return_all=True, - is_legacy=is_legacy) + b_matrix, _ = sh_to_sf_matrix(sphere, order, fodf_basis, legacy=is_legacy) + n = sph_harm_ind_list(fodf_basis) legendre0_at_n = lpn(order, 0)[0][n] sphere_norm = np.linalg.norm(sphere.vertices) @@ -123,7 +124,7 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, closest_vertex_indices, normalization_weights): vox_idx = tuple(vox_idx) - b_at_idx = b_matrix[closest_vertex_index] + b_at_idx = b_matrix.T[closest_vertex_index] fodf_at_index = fodf_data[vox_idx] afd_val = np.dot(b_at_idx, fodf_at_index) From 82c344c8f0a8822117152a1fbc3046f447828626 Mon Sep 17 00:00:00 2001 From: karp2601 Date: Thu, 29 Feb 2024 13:03:07 -0500 Subject: [PATCH 15/16] Fixing afd bug --- scilpy/tractanalysis/afd_along_streamlines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scilpy/tractanalysis/afd_along_streamlines.py b/scilpy/tractanalysis/afd_along_streamlines.py index fe401f597..e5f8d6aa8 100644 --- a/scilpy/tractanalysis/afd_along_streamlines.py +++ b/scilpy/tractanalysis/afd_along_streamlines.py @@ -85,7 +85,7 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis, order = find_order_from_nb_coeff(fodf_data) sphere = get_sphere('repulsion724') b_matrix, _ = sh_to_sf_matrix(sphere, order, fodf_basis, legacy=is_legacy) - n = sph_harm_ind_list(fodf_basis) + _, n = sph_harm_ind_list(order) legendre0_at_n = lpn(order, 0)[0][n] sphere_norm = np.linalg.norm(sphere.vertices) From e5bc15523f216d48d249e38ea6e11ea318fca552 Mon Sep 17 00:00:00 2001 From: Philippe Karan Date: Thu, 29 Feb 2024 19:48:28 -0500 Subject: [PATCH 16/16] Doc --- scilpy/io/utils.py | 2 ++ scripts/scil_sh_convert.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/scilpy/io/utils.py b/scilpy/io/utils.py index 3827cccbd..90c9d1b98 100644 --- a/scilpy/io/utils.py +++ b/scilpy/io/utils.py @@ -243,6 +243,8 @@ def add_bbox_arg(parser): def add_sh_basis_args(parser, mandatory=False, input_output=False): """ Add spherical harmonics (SH) bases argument. + For more information about the bases, see + https://docs.dipy.org/stable/theory/sh_basis.html. Parameters ---------- diff --git a/scripts/scil_sh_convert.py b/scripts/scil_sh_convert.py index c2dc81c0e..5cd192521 100755 --- a/scripts/scil_sh_convert.py +++ b/scripts/scil_sh_convert.py @@ -6,7 +6,7 @@ 'descoteaux07', 'descoteaux07_legacy', 'tournier07' or 'tournier07_legacy'. Using the sh_basis argument, both the input and the output SH bases must be given, in the order. For more information about the bases, see -https://dipy.org/documentation/1.4.0./theory/sh_basis/. +https://docs.dipy.org/stable/theory/sh_basis.html. Formerly: scil_convert_sh_basis.py """