From f00db0994fd60256d3a8a7b816a23649f71d6086 Mon Sep 17 00:00:00 2001 From: Riccardo Date: Tue, 4 Jul 2023 15:55:19 +0200 Subject: [PATCH] Added new unet with timm backbones --- models/__init__.py | 3 ++- models/timmunet.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++ predict.py | 24 +++++++++++++++++--- requirements.txt | 3 ++- settings.py | 8 ++++--- train.py | 10 +++++++-- 6 files changed, 93 insertions(+), 10 deletions(-) create mode 100644 models/timmunet.py diff --git a/models/__init__.py b/models/__init__.py index 8c3e60b..a7cca17 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1,3 @@ from .baseline import UNetBaseline -from .tlmod import SimoidSegmentationModule \ No newline at end of file +from .tlmod import SimoidSegmentationModule +from .timmunet import UnetTimm \ No newline at end of file diff --git a/models/timmunet.py b/models/timmunet.py new file mode 100644 index 0000000..cc3e75e --- /dev/null +++ b/models/timmunet.py @@ -0,0 +1,55 @@ +import torch +import timm +from torchvision import transforms + +from models.baseline import UConvGroup + +class UnetTimm(torch.nn.Module): + def __init__(self, out_depth:int, backbone_name="efficientnet_b0", pretrained=True, decoder_scale = 1): + super().__init__() + self.backbone = timm.create_model( + backbone_name, + features_only=True, + pretrained=pretrained + ) + + self.upconvs = [] + + # get channels of backbone layers in inverted order (lower -> upper) + bb_channels = self.backbone.feature_info.channels()[::-1] + bb_channels.append(bb_channels[-1]) + + for i in range(len(bb_channels)-1): + if i == 0: + layer = UConvGroup(bb_channels[i], decoder_scale * bb_channels[i+1]) + else: + layer = UConvGroup((decoder_scale + 1) * bb_channels[i], decoder_scale * bb_channels[i+1]) + + self.upconvs.append(layer) + + self.upconvs = torch.nn.ModuleList(self.upconvs) + + self.normalize = transforms.Normalize( + mean=self.backbone.pretrained_cfg["mean"], + std=self.backbone.pretrained_cfg["std"], + ) 
+ +        self.out_conv = torch.nn.Conv2d(decoder_scale * bb_channels[-1], out_depth, kernel_size=3, padding=1) + + + +    def forward(self, x): +        #x = self.normalize(x) +        features = self.backbone(x)[::-1] + +        for i, f in enumerate(features): +            if i == 0: +                void_shape = list(f.shape) +                void_shape[1] = 0 +                p = self.upconvs[0](torch.empty(void_shape).to(x.device), f) +            else: +                p = self.upconvs[i](p, f) + +            #print(f"{i}: {x.shape}") + +        return self.out_conv(p) diff --git a/predict.py b/predict.py index 9dab47f..dbb58ba 100644 --- a/predict.py +++ b/predict.py @@ -12,13 +12,26 @@ from tqdm import tqdm import settings -from models import UNetBaseline +from models import UNetBaseline, UnetTimm + +available_models = { +    "baseline": { +        "model": UNetBaseline(in_depth=settings.input_channels, out_depth=1, depth_scale=settings.baseline_model_scale), +        "checkpoint": settings.baseline_checkpoint_path +    }, +    "tunet": { +        "model": UnetTimm(out_depth=1, backbone_name="efficientnet_b3", pretrained=False, decoder_scale=settings.timmunet_decoder_scale), +        "checkpoint": settings.timmunet_checkpoint_path +    } + +} # command line args parser = argparse.ArgumentParser(description='Segment pokemon cards') parser.add_argument("-file", dest="file", help="Input image", type=str, default=None) parser.add_argument("-folder", dest="folder", help="Folder where images (pngs or jpgs) are located", type=str, default=None) +parser.add_argument("-model", dest="model", help=f"Model name {available_models.keys()}", type=str, default="tunet") args = parser.parse_args() @@ -49,11 +62,16 @@ if ".png" in f or ".jpg" in f: files.append(os.path.join(args.folder, f)) +if args.model not in available_models.keys(): +    print(f"No model named: {args.model}. 
Available models: {available_models.keys()}") + exit(1) + +print(f"Using model: {args.model}") # load model -model = UNetBaseline(in_depth=settings.input_channels, out_depth=1, depth_scale=settings.baseline_model_scale) +model = available_models[args.model]["model"] -checkpoint = torch.load(settings.checkpoint_path) +checkpoint = torch.load(available_models[args.model]["checkpoint"]) weights = checkpoint["state_dict"] for key in list(weights): weights[key.replace("model.", "")] = weights.pop(key) diff --git a/requirements.txt b/requirements.txt index 6aeced9..b792300 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ matplotlib==3.5.2 opencv-python-headless==4.7.0.72 pillow==9.5.0 tqdm>=4.64.0 -torchvision==0.15.2 \ No newline at end of file +torchvision==0.15.2 +timm==0.9.2 \ No newline at end of file diff --git a/settings.py b/settings.py index 8ae9264..2d7e996 100644 --- a/settings.py +++ b/settings.py @@ -10,7 +10,7 @@ output_root = "output" # global settings -card_size = (512, 368) +card_size = (512, 352) # data preprocessing settings min_mask_occupaion_percent = 0.20 @@ -30,11 +30,13 @@ max_epochs = 150 input_channels = 3 use_noisy_labels = False -learn_rate = 1e-3 +learn_rate = 1e-4 baseline_model_scale = 2 +timmunet_decoder_scale = 1 # prediction settings -checkpoint_path = "checkpoints/baseline_150.ckpt" +baseline_checkpoint_path = "checkpoints/baseline_150.ckpt" +timmunet_checkpoint_path = "checkpoints/timmunet_eff3_0819.ckpt" ################################################################################ diff --git a/train.py b/train.py index 2732ec1..2aac6e9 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ import settings from datasets import PkmCardSegmentationDataModule -from models import SimoidSegmentationModule, UNetBaseline +from models import SimoidSegmentationModule, UnetTimm # data agumentation @@ -25,7 +25,13 @@ use_noisy=settings.use_noisy_labels ) -torch_model = UNetBaseline(in_depth=settings.input_channels, 
out_depth=1, depth_scale=settings.baseline_model_scale) +# replace this line with the model you want to train +torch_model = UnetTimm( + out_depth=1, + backbone_name="efficientnet_b3", + pretrained=True, + decoder_scale=settings.timmunet_decoder_scale +) model = SimoidSegmentationModule(torch_model, lr=settings.learn_rate)