Skip to content

Commit

Permalink
Merge pull request #198 from pyplati/visualisation-and-registration-u…
Browse files Browse the repository at this point in the history
…pdates

Visualisation and registration updates
  • Loading branch information
pchlap authored Sep 11, 2023
2 parents d2ade8b + 796a191 commit 1bca37e
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 145 deletions.
131 changes: 97 additions & 34 deletions platipy/imaging/registration/deformable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
control_point_spacing_distance_to_number,
)

from platipy.imaging import ImageVisualiser
from platipy.imaging.label.utils import get_com


def multiscale_demons(
registration_algorithm,
Expand All @@ -35,6 +38,7 @@ def multiscale_demons(
resolution_staging=None,
smoothing_sigmas=None,
iteration_staging=None,
interp_order=sitk.sitkLinear,
):
"""
Run the given registration algorithm in a multiscale fashion. The original scale should not be
Expand All @@ -60,7 +64,7 @@ def multiscale_demons(
fixed_images = []
moving_images = []

for resolution, smoothing_sigma in reversed(list(zip(resolution_staging, smoothing_sigmas))):
for resolution, smoothing_sigma in zip(resolution_staging, smoothing_sigmas):

isotropic_voxel_size_mm = None
shrink_factor = None
Expand All @@ -76,14 +80,17 @@ def multiscale_demons(
isotropic_voxel_size_mm=isotropic_voxel_size_mm,
shrink_factor=shrink_factor,
smoothing_sigma=smoothing_sigma,
interpolator=interp_order,
)
)

moving_images.append(
smooth_and_resample(
moving_image,
isotropic_voxel_size_mm=isotropic_voxel_size_mm,
shrink_factor=shrink_factor,
smoothing_sigma=smoothing_sigma,
interpolator=interp_order,
)
)

Expand All @@ -95,51 +102,94 @@ def multiscale_demons(
initial_displacement_field = sitk.TransformToDisplacementField(
initial_transform,
sitk.sitkVectorFloat64,
fixed_images[-1].GetSize(),
fixed_images[-1].GetOrigin(),
fixed_images[-1].GetSpacing(),
fixed_images[-1].GetDirection(),
fixed_image.GetSize(),
fixed_image.GetOrigin(),
fixed_image.GetSpacing(),
fixed_image.GetDirection(),
)
else:
if len(moving_image.GetSize()) == 2:
initial_displacement_field = sitk.Image(
fixed_images[-1].GetWidth(),
fixed_images[-1].GetHeight(),
fixed_image.GetWidth(),
fixed_image.GetHeight(),
sitk.sitkVectorFloat64,
)
elif len(moving_image.GetSize()) == 3:
initial_displacement_field = sitk.Image(
fixed_images[-1].GetWidth(),
fixed_images[-1].GetHeight(),
fixed_images[-1].GetDepth(),
fixed_image.GetWidth(),
fixed_image.GetHeight(),
fixed_image.GetDepth(),
sitk.sitkVectorFloat64,
)
initial_displacement_field.CopyInformation(fixed_images[-1])
initial_displacement_field.CopyInformation(fixed_image)
else:
initial_displacement_field = sitk.Resample(initial_displacement_field, fixed_images[-1])
initial_displacement_field = sitk.Resample(
initial_displacement_field, fixed_image
)

# Run the registration.
iters = iteration_staging[0]
registration_algorithm.SetNumberOfIterations(iters)
initial_displacement_field = registration_algorithm.Execute(
fixed_images[-1], moving_images[-1], initial_displacement_field
)
# Start at the top of the pyramid and work our way down.
for i, (f_image, m_image) in enumerate(
reversed(list(zip(fixed_images[0:-1], moving_images[0:-1])))
):
initial_displacement_field = sitk.Resample(initial_displacement_field, f_image)
iters = iteration_staging[i + 1]
registration_algorithm.SetNumberOfIterations(iters)
initial_displacement_field = registration_algorithm.Execute(
f_image, m_image, initial_displacement_field

dvf_total = sitk.Resample(initial_displacement_field, fixed_image)

for i in range(len(fixed_images)):
f_image = fixed_images[i]
m_image = moving_images[i]

# we now apply the (total) transform to the moving image
dvf_total = sitk.Resample(dvf_total, f_image)

tfm_total = sitk.DisplacementFieldTransform(
sitk.Cast(dvf_total, sitk.sitkVectorFloat64)
)
m_image = sitk.Resample(m_image, tfm_total, interp_order)

output_displacement_field = sitk.Resample(
initial_displacement_field, initial_displacement_field
)
# set up iteration staging
iters = iteration_staging[i]
registration_algorithm.SetNumberOfIterations(iters)

# set up regularisation
# leave constant (in image units) for now

dvf_iter = registration_algorithm.Execute(f_image, m_image)

# and now add to the running DVF
# importly, at each voxel the deformation vector (source -> destination)
# has to be updated with the vector field itself
dvf_total = dvf_total + sitk.Resample(dvf_iter, tfm_total)

return sitk.DisplacementFieldTransform(initial_displacement_field), output_displacement_field
# manually smooth the DVF
sigma = registration_algorithm.GetStandardDeviations()
dvf_total = sitk.SmoothingRecursiveGaussian(dvf_total, sigma)
dvf_total = sitk.Cast(dvf_total, sitk.sitkVectorFloat64)

# vis = ImageVisualiser(f_image, cut=get_com(f_image), figure_size_in=6)
# vis.add_comparison_overlay(m_image)

# vis.add_vector_overlay(
# dvf_iter,
# arrow_scale=0.25,
# arrow_width=0.25,
# subsample=4,
# )

# vis.set_limits_from_label(f_image, expansion=100)
# fig = vis.show()

# test_tfm = sitk.DisplacementFieldTransform(
# sitk.Cast(dvf_iter, sitk.sitkVectorFloat64)
# )
# test = sitk.Resample(m_image, test_tfm)

# vis = ImageVisualiser(f_image > 0, cut=get_com(f_image), figure_size_in=6)
# vis.add_comparison_overlay(test > 0)

# vis.set_limits_from_label(f_image, expansion=100)
# fig = vis.show()

dvf_total = sitk.Resample(dvf_total, fixed_image)

return dvf_total


def fast_symmetric_forces_demons_registration(
Expand All @@ -149,6 +199,7 @@ def fast_symmetric_forces_demons_registration(
iteration_staging=[10, 10, 10],
isotropic_resample=False,
initial_displacement_field=None,
regularisation_kernel_mm=1.5,
smoothing_sigma_factor=1,
smoothing_sigmas=False,
default_value=None,
Expand All @@ -169,6 +220,7 @@ def fast_symmetric_forces_demons_registration(
case resolution_staging is used to define voxel size
(mm) per level
initial_displacement_field (sitk.Image) : Initial displacement field to use
regularisation_kernel_scale (float) : Relative scale (var/voxel size) of the regularisation kernel (Gaussian)
ncores (int) : number of processing cores to use
smoothing_sigma_factor (float) : the relative width of the Gaussian smoothing kernel
interp_order (int) : the interpolation order
Expand Down Expand Up @@ -200,7 +252,14 @@ def fast_symmetric_forces_demons_registration(
registration_method.SetNumberOfThreads(ncores)
registration_method.SetSmoothUpdateField(True)
registration_method.SetSmoothDisplacementField(True)
registration_method.SetStandardDeviations(1.5)

# This is the regularisation kernel
# values are set in image (voxel) coordinates
regularisation_kernel_vox = np.array(regularisation_kernel_mm) / np.array(
fixed_image.GetSpacing()
)
print("regularisation_kernel_vox", regularisation_kernel_vox)
registration_method.SetStandardDeviations(regularisation_kernel_vox.tolist())

# This allows monitoring of the progress
if verbose:
Expand All @@ -212,7 +271,7 @@ def fast_symmetric_forces_demons_registration(
if not smoothing_sigmas:
smoothing_sigmas = [i * smoothing_sigma_factor for i in resolution_staging]

output_transform, deformation_field = multiscale_demons(
deformation_field = multiscale_demons(
registration_algorithm=registration_method,
fixed_image=fixed_image,
moving_image=moving_image,
Expand All @@ -221,6 +280,7 @@ def fast_symmetric_forces_demons_registration(
iteration_staging=iteration_staging,
isotropic_resample=isotropic_resample,
initial_displacement_field=initial_displacement_field,
interp_order=interp_order,
)

resampler = sitk.ResampleImageFilter()
Expand All @@ -237,15 +297,18 @@ def fast_symmetric_forces_demons_registration(

resampler.SetDefaultPixelValue(default_value)

# create the deformable transform
output_transform = sitk.DisplacementFieldTransform(
sitk.Cast(deformation_field, sitk.sitkVectorFloat64)
)

resampler.SetTransform(output_transform)
registered_image = resampler.Execute(moving_image)

registered_image.CopyInformation(fixed_image)
registered_image = sitk.Cast(registered_image, moving_image_type)

resampled_field = sitk.Resample(deformation_field, fixed_image)

return registered_image, output_transform, resampled_field
return registered_image, output_transform, deformation_field


def bspline_registration(
Expand Down
2 changes: 1 addition & 1 deletion platipy/imaging/registration/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def linear_registration(
# to do: add the rest

registration.SetInterpolator(sitk.sitkLinear) # Perhaps a small gain in improvement
registration.SetMetricSamplingPercentage(sampling_rate)
registration.SetMetricSamplingPercentage(sampling_rate, seed=42)
registration.SetMetricSamplingStrategy(sitk.ImageRegistrationMethod.REGULAR)

# This is only necessary if using a transform comprising changes with different units
Expand Down
4 changes: 3 additions & 1 deletion platipy/imaging/visualisation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
min_value=False,
max_value=False,
discrete_levels=False,
show_as_contours=False,
mid_ticks=False,
show_colorbar=True,
norm=None,
Expand All @@ -70,6 +71,7 @@ def __init__(
self.min_value = min_value
self.max_value = max_value
self.discrete_levels = discrete_levels
self.show_as_contours = show_as_contours
self.mid_ticks = mid_ticks
self.show_colorbar = show_colorbar
self.norm = norm
Expand Down Expand Up @@ -250,7 +252,7 @@ def reorientate_vector_field(axis, vector_ax, vector_cor, vector_sag, invert_fie
if axis == "y": # coronal projection
return vector_sag, vector_ax, vector_cor
if axis == "z": # axial projection
return vector_sag, -vector_cor, vector_ax
return -vector_sag, -vector_cor, vector_ax

return None

Expand Down
Loading

0 comments on commit 1bca37e

Please sign in to comment.