diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index 0d7c3a12..afffdc85 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -360,6 +360,11 @@ def infer( s = s.unsqueeze(0) s = s.unsqueeze(0) + # Add in background channel + not_s = 1 - sample_strategy.max(axis=1).values + not_s = torch.unsqueeze(not_s, dim=1) + s = torch.cat((not_s, s), dim=1).float() + if self.hparams.prob_type == "prob": if seg is not None: self.prob_unet.forward(x, seg=s, training=True)