Skip to content

Commit

Permalink
Merge branch 'add-hd-95-metric' into prob-unet
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Dec 21, 2023
2 parents d65ec7a + 34e2ebb commit a128cea
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 21 deletions.
13 changes: 6 additions & 7 deletions platipy/imaging/label/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def compute_surface_dsc(label_a, label_b, tau=3.0):

def compute_surface_metrics(label_a, label_b, verbose=False):
"""Compute surface distance metrics between two labels. Surface metrics computed are:
hausdorffDistance, meanSurfaceDistance, medianSurfaceDistance, maximumSurfaceDistance,
sigmaSurfaceDistance, surfaceDSC
hausdorffDistance, hausdorffDistance95, meanSurfaceDistance, medianSurfaceDistance,
maximumSurfaceDistance, sigmaSurfaceDistance, surfaceDSC
Args:
label_a (sitk.Image): A mask to compare
Expand All @@ -95,8 +95,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False):
std_sd_list = []
median_sd_list = []
num_points = []
for (la, lb) in ((label_a, label_b), (label_b, label_a)):

for la, lb in ((label_a, label_b), (label_b, label_a)):
label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter()
reference_distance_map = sitk.Abs(
sitk.SignedMaurerDistanceMap(
Expand All @@ -118,6 +117,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False):

mean_surf_dist = np.dot(mean_sd_list, num_points) / np.sum(num_points)
max_surf_dist = np.max(max_sd_list)
hd_95 = np.percentile(max_sd_list, 95)
std_surf_dist = np.sqrt(
np.dot(
num_points,
Expand All @@ -131,6 +131,7 @@ def compute_surface_metrics(label_a, label_b, verbose=False):

result = {}
result["hausdorffDistance"] = hd
result["hausdorffDistance95"] = hd_95
result["meanSurfaceDistance"] = mean_surf_dist
result["medianSurfaceDistance"] = median_surf_dist
result["maximumSurfaceDistance"] = max_surf_dist
Expand Down Expand Up @@ -294,8 +295,7 @@ def compute_metric_masd(label_a, label_b, auto_crop=True):

mean_sd_list = []
num_points = []
for (la, lb) in ((label_a, label_b), (label_b, label_a)):

for la, lb in ((label_a, label_b), (label_b, label_a)):
label_intensity_stat = sitk.LabelIntensityStatisticsImageFilter()
reference_distance_map = sitk.Abs(
sitk.SignedMaurerDistanceMap(
Expand Down Expand Up @@ -364,7 +364,6 @@ def compute_apl(label_ref, label_test, distance_threshold_mm=3):

# iterate over each slice
for i in range(n_slices):

if (
sitk.GetArrayViewFromImage(label_ref)[i].sum()
+ sitk.GetArrayViewFromImage(label_test)[i].sum()
Expand Down
77 changes: 63 additions & 14 deletions services/nnunet/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,32 @@

import os
import subprocess
import json

from pathlib import Path

import logging
import SimpleITK as sitk

from platipy.backend import app, DataObject, celery # pylint: disable=unused-import
from platipy.backend import app, DataObject, celery # pylint: disable=unused-import

logger = logging.getLogger(__name__)

NNUNET_SETTINGS_DEFAULTS = {
"task": "TaskXXX",
"config": "2d",
"trainer": None,
"fold": None,
"clean_sup_slices": False,
}


def clean_sup_slices(mask):
lssif = sitk.LabelShapeStatisticsImageFilter()
max_slice_size = 0
sizes = {}
for z in range(mask.GetSize()[2]-1, -1, -1):
lssif.Execute(sitk.ConnectedComponent(mask[:,:,z]))
for z in range(mask.GetSize()[2] - 1, -1, -1):
lssif.Execute(sitk.ConnectedComponent(mask[:, :, z]))
if len(lssif.GetLabels()) == 0:
continue

Expand All @@ -47,13 +50,40 @@ def clean_sup_slices(mask):

sizes[z] = phys_size
for z in sizes:
if sizes[z] > max_slice_size/2:
mask[:,:,z+1:mask.GetSize()[2]] = 0
if sizes[z] > max_slice_size / 2:
mask[:, :, z + 1 : mask.GetSize()[2]] = 0
break

return mask


def get_structure_names(task):
# Look up structure names if we can find them dataset.json file
if "nnUNet_raw_data_base" not in os.environ:
logger.info("nnUNet_raw_data_base not set")
return {}

raw_path = Path(os.environ["nnUNet_raw_data_base"])
task_path = raw_path.joinpath("nnUNet_raw_data", task)
dataset_file = task_path.joinpath("dataset.json")

logger.info("Attempting to read %s", dataset_file)

if not dataset_file.exists():
logger.info("dataset.json file does not exist for %s", dataset_file)
return {}

dataset = {}
with open(dataset_file, "r") as f:
dataset = json.load(f)

if "labels" not in dataset:
logger.info("Something went wrong reading dataset.json file")
return {}

return dataset["labels"]


@app.register("nnUNet Service", default_settings=NNUNET_SETTINGS_DEFAULTS)
def nnunet_service(data_objects, working_dir, settings):
"""
Expand All @@ -72,8 +102,10 @@ def nnunet_service(data_objects, working_dir, settings):
output_path = Path(working_dir).joinpath("output")
output_path.mkdir()

for data_object in data_objects:
labels = get_structure_names(settings["task"])
logger.info("Read labels: %s", labels)

for data_object in data_objects:
# Create a symbolic link for each image to auto-segment using the nnUNet
do_path = Path(data_object.path)
io_path = input_path.joinpath(f"{settings['task']}_0000.nii.gz")
Expand All @@ -98,21 +130,39 @@ def nnunet_service(data_objects, working_dir, settings):
settings["config"],
]

if settings["trainer"]:
if "fold" in settings and settings["fold"]:
command += ["-f", settings["fold"]]

if "trainer" in settings and settings["trainer"]:
command += ["-tr", settings["trainer"]]

logger.info("Running command: %s", command)
subprocess.call(command)

for op in output_path.glob("*.nii.gz"):
label_map = sitk.ReadImage(str(op))

label_map_arr = sitk.GetArrayFromImage(label_map)
label_count = label_map_arr.max()

for label_id in range(1, label_count + 1):
mask = label_map == label_id

if settings["clean_sup_slices"]:
mask = sitk.ReadImage(str(op))
mask = clean_sup_slices(mask)
sitk.WriteImage(mask, str(op))
label_name = f"Structure_{label_id}"
if str(label_id) in labels:
label_name = labels[str(label_id)]

output_data_object = DataObject(type="FILE", path=str(op), parent=data_object)
output_objects.append(output_data_object)
if settings["clean_sup_slices"]:
mask = clean_sup_slices(mask)

mask_file = output_path.joinpath(f"{label_name}.nii.gz")

sitk.WriteImage(mask, str(mask_file))

output_data_object = DataObject(
type="FILE", path=str(mask_file), parent=data_object
)
output_objects.append(output_data_object)

os.remove(io_path)

Expand All @@ -122,7 +172,6 @@ def nnunet_service(data_objects, working_dir, settings):


if __name__ == "__main__":

# Run app by calling "python service.py" from the command line

DICOM_LISTENER_PORT = 7777
Expand Down

0 comments on commit a128cea

Please sign in to comment.