Skip to content

Commit

Permalink
Add background channel
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Apr 1, 2024
1 parent f20acf6 commit ec18dbc
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions platipy/imaging/cnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,11 @@ def infer(
s = s.unsqueeze(0)
s = s.unsqueeze(0)

# Add in background channel
not_s = 1 - sample_strategy.max(axis=1).values
not_s = torch.unsqueeze(not_s, dim=1)
s = torch.cat((not_s, s), dim=1).float()

if self.hparams.prob_type == "prob":
if seg is not None:
self.prob_unet.forward(x, seg=s, training=True)
Expand Down

0 comments on commit ec18dbc

Please sign in to comment.