diff --git a/pyproject.toml b/pyproject.toml index 3445268..5bba304 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "YOEO" -version = "1.1.0" +version = "1.1.1" description = "A hybrid CNN for object detection and semantic segmentation" authors = ["Florian Vahl ", "Jan Gutsche "] diff --git a/yoeo/models.py b/yoeo/models.py index d818c69..1b69cd8 100644 --- a/yoeo/models.py +++ b/yoeo/models.py @@ -179,7 +179,7 @@ def forward(self, x): if self.training: return x else: - return torch.argmax(x, dim=1) + return torch.argmax(x, dim=1).to(torch.uint8) class Darknet(nn.Module):