Skip to content

Commit

Permalink
Allow inference using a reference segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Apr 1, 2024
1 parent a128cea commit b56a876
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
11 changes: 10 additions & 1 deletion platipy/imaging/cnn/prob_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None):

return self.fcomb.forward(self.unet_features, z_prior)

def reconstruct(self, use_posterior_mean=False, z_posterior=None):
def reconstruct(self, use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None):
"""
Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet
feature map
Expand All @@ -298,6 +298,15 @@ def reconstruct(self, use_posterior_mean=False, z_posterior=None):
"""
if use_posterior_mean:
z_posterior = self.posterior_latent_space.mean
elif sample_x_stddev_from_mean is not None:
if isinstance(sample_x_stddev_from_mean, list):
sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean)
sample_x_stddev_from_mean = sample_x_stddev_from_mean.to(
self.posterior_latent_space.base_dist.stddev.device
)
z_posterior = self.posterior_latent_space.base_dist.loc + (
self.posterior_latent_space.base_dist.scale * sample_x_stddev_from_mean
)
else:
if z_posterior is None:
z_posterior = self.posterior_latent_space.rsample()
Expand Down
45 changes: 38 additions & 7 deletions platipy/imaging/cnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def infer(
self,
img,
context_map=None,
seg=None,
num_samples=1,
sample_strategy="mean",
latent_dim=True,
Expand Down Expand Up @@ -314,22 +315,34 @@ def infer(
intensity_window=self.hparams.intensity_window,
)



img_arr = sitk.GetArrayFromImage(img)

if context_map is not None:
context_map = resample_mask_to_image(img, context_map)
cmap_arr = sitk.GetArrayFromImage(img)

if seg is not None:
seg = resample_mask_to_image(img, seg)
seg_arr = sitk.GetArrayFromImage(img)

if self.hparams.ndims == 2:
slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])]

if context_map is not None:
cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])]

if seg is not None:
seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])]
else:
slices = [img_arr]
if context_map is not None:
cmap_slices = [cmap_arr]

if seg is not None:
seg_slices = [seg_arr]

for idx, i in enumerate(slices):
x = torch.Tensor(i).to(self.device)
x = x.unsqueeze(0)
Expand All @@ -342,19 +355,37 @@ def infer(

x = torch.cat((x, c), dim=1)

if seg is not None:
s = torch.Tensor(seg_slices[idx]).to(self.device)
s = s.unsqueeze(0)
s = s.unsqueeze(0)

if self.hparams.prob_type == "prob":
self.prob_unet.forward(x)
if seg is not None:
self.prob_unet.forward(img, seg=seg, training=True)
else:
self.prob_unet.forward(x)

for sample in samples:
if self.hparams.prob_type == "prob":
if sample["name"] == "mean":
y = self.prob_unet.sample(testing=True, use_mean=True)
if seg is None:
y = self.prob_unet.sample(testing=True, use_mean=True)
else:
y = self.prob_unet.reconstruct(use_posterior_mean=True)
else:
y = self.prob_unet.sample(
testing=True,
use_mean=False,
sample_x_stddev_from_mean=sample["std_dev_from_mean"],
)
if seg is None:
y = self.prob_unet.sample(
testing=True,
use_mean=False,
sample_x_stddev_from_mean=sample["std_dev_from_mean"],
)
else:
y = self.prob_unet.reconstruct(
use_posterior_mean=False,
sample_x_stddev_from_mean=sample["std_dev_from_mean"],
)

# else:
# if sample["name"] == "mean":
# y = self.prob_unet.sample(x, mean=True)
Expand Down

0 comments on commit b56a876

Please sign in to comment.