diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 5ec0da91..7f565663 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -289,7 +289,7 @@ def sample(self, testing=False, use_mean=False, sample_x_stddev_from_mean=None): return self.fcomb.forward(self.unet_features, z_prior) - def reconstruct(self, use_posterior_mean=False, z_posterior=None): + def reconstruct(self, use_posterior_mean=False, z_posterior=None, sample_x_stddev_from_mean=None): """ Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map @@ -298,6 +298,15 @@ def reconstruct(self, use_posterior_mean=False, z_posterior=None): """ if use_posterior_mean: z_posterior = self.posterior_latent_space.mean + elif sample_x_stddev_from_mean is not None: + if isinstance(sample_x_stddev_from_mean, list): + sample_x_stddev_from_mean = torch.Tensor(sample_x_stddev_from_mean) + sample_x_stddev_from_mean = sample_x_stddev_from_mean.to( + self.posterior_latent_space.base_dist.stddev.device + ) + z_posterior = self.posterior_latent_space.base_dist.loc + ( + self.posterior_latent_space.base_dist.scale * sample_x_stddev_from_mean + ) else: if z_posterior is None: z_posterior = self.posterior_latent_space.rsample() diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b230b3f1..fb576803 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -242,6 +242,7 @@ def infer( self, img, context_map=None, + seg=None, num_samples=1, sample_strategy="mean", latent_dim=True, @@ -314,22 +315,34 @@ def infer( intensity_window=self.hparams.intensity_window, ) + + img_arr = sitk.GetArrayFromImage(img) if context_map is not None: context_map = resample_mask_to_image(img, context_map) cmap_arr = sitk.GetArrayFromImage(img) + if seg is not None: + seg = resample_mask_to_image(img, seg) + seg_arr = sitk.GetArrayFromImage(img) + if self.hparams.ndims == 2: slices = [img_arr[z, :, :] for z in range(img_arr.shape[0])] if context_map is not None: cmap_slices = [cmap_arr[z, :, :] for z in range(cmap_arr.shape[0])] + + if seg is not None: + seg_slices = [seg_arr[z, :, :] for z in range(seg_arr.shape[0])] else: slices = [img_arr] if context_map is not None: cmap_slices = [cmap_arr] + if seg is not None: + seg_slices = [seg_arr] + for idx, i in enumerate(slices): x = torch.Tensor(i).to(self.device) x = x.unsqueeze(0) @@ -342,19 +355,37 @@ def infer( x = torch.cat((x, c), dim=1) + if seg is not None: + s = torch.Tensor(seg_slices[idx]).to(self.device) + s = s.unsqueeze(0) + s = s.unsqueeze(0) + if self.hparams.prob_type == "prob": - self.prob_unet.forward(x) + if seg is not None: + self.prob_unet.forward(img, seg=seg, training=True) + else: + self.prob_unet.forward(x) for sample in samples: if self.hparams.prob_type == "prob": if sample["name"] == "mean": - y = self.prob_unet.sample(testing=True, use_mean=True) + if seg is None: + y = self.prob_unet.sample(testing=True, use_mean=True) + else: + y = self.prob_unet.reconstruct(use_posterior_mean=True) else: - y = self.prob_unet.sample( - testing=True, - use_mean=False, - sample_x_stddev_from_mean=sample["std_dev_from_mean"], - ) + if seg is None: + y = self.prob_unet.sample( + testing=True, + use_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + else: + y = self.prob_unet.reconstruct( + use_posterior_mean=False, + sample_x_stddev_from_mean=sample["std_dev_from_mean"], + ) + # else: # if sample["name"] == "mean": # y = self.prob_unet.sample(x, mean=True)