Skip to content

Commit

Permalink
Merge pull request scilus#921 from karanphil/add_sh_basis_legacy_support
Browse files Browse the repository at this point in the history
Add SH basis legacy support
  • Loading branch information
arnaudbore authored Mar 1, 2024
2 parents 1c3029a + 1bb9898 commit d319ad2
Show file tree
Hide file tree
Showing 34 changed files with 357 additions and 194 deletions.
30 changes: 24 additions & 6 deletions scilpy/denoise/asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
def angle_aware_bilateral_filtering(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0, sigma_angular=1.0,
sigma_range=0.5, use_gpu=True):
Expand All @@ -27,6 +28,8 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -46,6 +49,7 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
if use_gpu and have_opencl:
return angle_aware_bilateral_filtering_gpu(in_sh, sh_order,
sh_basis, in_full_basis,
is_legacy,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)
elif use_gpu and not have_opencl:
Expand All @@ -54,13 +58,15 @@ def angle_aware_bilateral_filtering(in_sh, sh_order=8,
else:
return angle_aware_bilateral_filtering_cpu(in_sh, sh_order,
sh_basis, in_full_basis,
is_legacy,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)


def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0,
sigma_angular=1.0,
Expand All @@ -78,6 +84,8 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -104,11 +112,13 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
sh_to_sf_mat = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=in_full_basis,
legacy=is_legacy,
return_inv=False)

_, sf_to_sh_mat = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=True,
legacy=is_legacy,
return_inv=True)

out_n_coeffs = sf_to_sh_mat.shape[1]
Expand Down Expand Up @@ -150,6 +160,7 @@ def angle_aware_bilateral_filtering_gpu(in_sh, sh_order=8,
def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,
sh_basis='descoteaux07',
in_full_basis=False,
is_legacy=True,
sphere_str='repulsion724',
sigma_spatial=1.0,
sigma_angular=1.0,
Expand All @@ -168,6 +179,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,
Name of SH basis used.
in_full_basis: bool, optional
True if input is expressed in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
sphere_str: str, optional
Name of the DIPY sphere to use for sh to sf projection.
sigma_spatial: float, optional
Expand All @@ -194,7 +207,8 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,

nb_sf = len(sphere.vertices)
B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
return_inv=False, full_basis=in_full_basis)
return_inv=False, full_basis=in_full_basis,
legacy=is_legacy)

mean_sf = np.zeros(in_sh.shape[:-1] + (nb_sf,))

Expand All @@ -209,7 +223,7 @@ def angle_aware_bilateral_filtering_cpu(in_sh, sh_order=8,

# Convert back to SH coefficients
_, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
full_basis=True)
full_basis=True, legacy=is_legacy)
out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype)
# By default, return only asymmetric SH
return out_sh
Expand Down Expand Up @@ -371,7 +385,7 @@ def _correlate_spatial(image_u, h_filter, sigma_range):


def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
in_full_basis=False, dot_sharpness=1.0,
in_full_basis=False, is_legacy=True, dot_sharpness=1.0,
sphere_str='repulsion724', sigma=1.0):
"""
Average the SH projected on a sphere using a first-neighbor gaussian
Expand All @@ -389,6 +403,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
SH basis of the input signal.
in_full_basis: bool, optional
True if the input is in full SH basis.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
dot_sharpness: float, optional
Exponent of the dot product. When set to 0.0, directions
are not weighted by the dot product.
Expand All @@ -411,13 +427,14 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
nb_sf = len(sphere.vertices)
mean_sf = np.zeros(np.append(in_sh.shape[:-1], nb_sf))
B = sh_to_sf_matrix(sphere, sh_order=sh_order, basis_type=sh_basis,
return_inv=False, full_basis=in_full_basis)
return_inv=False, full_basis=in_full_basis,
legacy=is_legacy)

# We want a B matrix to project on an inverse sphere to have the sf on
# the opposite hemisphere for a given vertice
neg_B = sh_to_sf_matrix(Sphere(xyz=-sphere.vertices), sh_order=sh_order,
basis_type=sh_basis, return_inv=False,
full_basis=in_full_basis)
full_basis=in_full_basis, legacy=is_legacy)

# Apply filter to each sphere vertice
for sf_i in range(nb_sf):
Expand All @@ -435,7 +452,8 @@ def cosine_filtering(in_sh, sh_order=8, sh_basis='descoteaux07',
# Convert back to SH coefficients
_, B_inv = sh_to_sf_matrix(sphere, sh_order=sh_order,
basis_type=sh_basis,
full_basis=True)
full_basis=True,
legacy=is_legacy)

out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], dtype=in_sh.dtype)
return out_sh
Expand Down
4 changes: 2 additions & 2 deletions scilpy/denoise/tests/test_asym_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_angle_aware_bilateral_filtering():

sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1])
out = angle_aware_bilateral_filtering_cpu(in_sh, sh_order,
sh_basis, full_basis,
sh_basis, full_basis, True,
sphere_str, sigma_spatial,
sigma_angular, sigma_range)

Expand All @@ -40,7 +40,7 @@ def test_cosine_filtering():
sharpness = 1.0

sh_order, full_basis = get_sh_order_and_fullness(in_sh.shape[-1])
out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis,
out = cosine_filtering(in_sh, sh_order, sh_basis, full_basis, True,
sharpness, sphere_str, sigma_spatial)

assert np.allclose(out, fodf_3x3_order8_descoteaux07_filtered_cosine)
79 changes: 69 additions & 10 deletions scilpy/io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,36 +283,95 @@ def add_bbox_arg(parser):
'streamlines).')


def add_sh_basis_args(parser, mandatory=False):
"""Add spherical harmonics (SH) bases argument.
def add_sh_basis_args(parser, mandatory=False, input_output=False):
"""
Add spherical harmonics (SH) bases argument. For more information about
the bases, see https://docs.dipy.org/stable/theory/sh_basis.html.
Parameters
----------
parser: argparse.ArgumentParser object
Parser.
mandatory: bool
Whether this argument is mandatory.
input_output: bool
Whether this argument should expect both input and output bases or not.
If set, the sh_basis argument will expect first the input basis,
followed by the output basis.
"""
choices = ['descoteaux07', 'tournier07']
def_val = 'descoteaux07'
if input_output:
nargs = 2
def_val = ['descoteaux07_legacy', 'tournier07']
input_output_msg = '\nBoth the input and output bases are ' +\
'required, in that order.'
else:
nargs = 1
def_val = ['descoteaux07_legacy']
input_output_msg = ''

choices = ['descoteaux07', 'tournier07', 'descoteaux07_legacy',
'tournier07_legacy']
help_msg = 'Spherical harmonics basis used for the SH coefficients. ' +\
'\nMust be either \'descoteaux07\' or \'tournier07\'' +\
input_output_msg +\
'\nMust be either \'descoteaux07\', \'tournier07\', \n' +\
'\'descoteaux07_legacy\' or \'tournier07_legacy\'' +\
' [%(default)s]:\n' +\
' \'descoteaux07\': SH basis from the Descoteaux et al.\n' +\
' MRM 2007 paper\n' +\
' \'tournier07\' : SH basis from the Tournier et al.\n' +\
' NeuroImage 2007 paper.'
' \'descoteaux07\' : SH basis from the Descoteaux ' +\
'et al.\n' +\
' MRM 2007 paper\n' +\
' \'tournier07\' : SH basis from the new ' +\
'Tournier et al.\n' +\
' NeuroImage 2019 paper, as in ' +\
'MRtrix 3.\n' +\
' \'descoteaux07_legacy\': SH basis from the legacy Dipy ' +\
'implementation\n' +\
' of the ' +\
'Descoteaux et al. MRM 2007 paper\n' +\
' \'tournier07_legacy\' : SH basis from the legacy ' +\
'Tournier et al.\n' +\
' NeuroImage 2007 paper.'

if mandatory:
arg_name = 'sh_basis'
else:
arg_name = '--sh_basis'

parser.add_argument(arg_name,
parser.add_argument(arg_name, nargs=nargs,
choices=choices, default=def_val,
help=help_msg)


def parse_sh_basis_arg(args):
"""
Parser the input from args.sh_basis. If two SH bases are given,
both input/output sh_basis and is_legacy are returned.
Parameters
----------
args : ArgumentParser.parse_args
ArgumentParser.parse_args from a script.
Returns
-------
sh_basis : string
Spherical harmonic basis name.
is_legacy : bool
Whether or not the SH basis is in its legacy form.
"""
sh_basis_name = args.sh_basis[0]
sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \
else 'tournier07'
is_legacy = 'legacy' in sh_basis_name
if len(args.sh_basis) == 2:
sh_basis_name = args.sh_basis[1]
out_sh_basis = 'descoteaux07' if 'descoteaux07' in sh_basis_name \
else 'tournier07'
is_out_legacy = 'legacy' in sh_basis_name
return sh_basis, is_legacy, out_sh_basis, is_out_legacy
else:
return sh_basis, is_legacy


def add_nifti_screenshot_default_args(
parser, slice_ids_mandatory=True, transparency_mask_mandatory=True
):
Expand Down
18 changes: 12 additions & 6 deletions scilpy/reconst/fodf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from dipy.data import get_sphere
from dipy.reconst.mcsd import MSDeconvFit
from dipy.reconst.multi_voxel import MultiVoxelFit
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.utils import find_order_from_nb_coeff, get_b_matrix
from scilpy.reconst.utils import find_order_from_nb_coeff

from dipy.utils.optpkg import optional_package
cvx, have_cvxpy, _ = optional_package("cvxpy")


def get_ventricles_max_fodf(data, fa, md, zoom, args):
def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, args,
is_legacy=True):
"""
Compute mean maximal fodf value in ventricules. Given
heuristics thresholds on FA and MD values, finds the
Expand All @@ -30,9 +32,13 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):
FA (Fractional Anisotropy) volume from DTI
md: ndarray (x, y, z)
MD (Mean Diffusivity) volume from DTI
vol: int > 0
Maximum Nnumber of voxels used to compute the mean.
zoom: int > 0
Maximum number of voxels used to compute the mean.
1000 works well at 2x2x2 = 8 mm3
sh_basis: str
Either 'tournier07' or 'descoteaux07'
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
Returns
-------
Expand All @@ -42,7 +48,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):

order = find_order_from_nb_coeff(data)
sphere = get_sphere('repulsion100')
b_matrix = get_b_matrix(order, sphere, args.sh_basis)
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy)
sum_of_max = 0
count = 0

Expand Down Expand Up @@ -86,7 +92,7 @@ def get_ventricles_max_fodf(data, fa, md, zoom, args):
continue
if fa[i, j, k] < args.fa_threshold \
and md[i, j, k] > args.md_threshold:
sf = np.dot(data[i, j, k], b_matrix.T)
sf = np.dot(data[i, j, k], b_matrix)
sum_of_max += sf.max()
count += 1
mask[i, j, k] = 1
Expand Down
Loading

0 comments on commit d319ad2

Please sign in to comment.