Skip to content

Commit

Permalink
Added new unet with timm backbones
Browse files Browse the repository at this point in the history
  • Loading branch information
rickycorte committed Jul 4, 2023
1 parent 2448c06 commit f00db09
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 10 deletions.
3 changes: 2 additions & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .baseline import UNetBaseline
from .tlmod import SimoidSegmentationModule
from .tlmod import SimoidSegmentationModule
from .timmunet import UnetTimm
55 changes: 55 additions & 0 deletions models/timmunet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import timm
from torchvision import transforms

from models.baseline import UConvGroup

class UnetTimm(torch.nn.Module):
    """U-Net style segmentation network on top of a timm feature backbone.

    The backbone is created with ``features_only=True`` and therefore returns
    a sequence of feature maps. The decoder owns one ``UConvGroup`` per
    feature map and walks them from the deepest map to the shallowest one,
    handing each group the running decoder output together with the matching
    backbone feature map. A final 3x3 convolution maps the decoder output to
    ``out_depth`` channels.

    NOTE(review): ``UConvGroup`` is defined in ``models.baseline``; the channel
    arithmetic below presumes it joins its two inputs along the channel axis.
    Confirm against that module.

    Args:
        out_depth: Number of channels produced by the final convolution.
        backbone_name: Name of the timm model used as encoder.
        pretrained: Whether timm should load pretrained backbone weights.
        decoder_scale: Multiplier applied to the backbone channel counts to
            obtain the decoder channel counts.
    """

    def __init__(self, out_depth:int, backbone_name="efficientnet_b0", pretrained=True, decoder_scale = 1):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            features_only=True,
            pretrained=pretrained
        )

        # Channel counts of the backbone feature maps in inverted order
        # (deepest first). The last entry is repeated so that the final
        # decoder stage also has a target width to project to.
        bb_channels = self.backbone.feature_info.channels()[::-1]
        bb_channels.append(bb_channels[-1])

        upconvs = []
        for i, (stage_channels, next_channels) in enumerate(zip(bb_channels, bb_channels[1:])):
            if i == 0:
                # The deepest stage receives no previous decoder output, only
                # the backbone feature map itself.
                in_channels = stage_channels
            else:
                # Previous decoder output (decoder_scale * stage_channels)
                # plus the backbone feature map (stage_channels).
                in_channels = (decoder_scale + 1) * stage_channels
            upconvs.append(UConvGroup(in_channels, decoder_scale * next_channels))

        self.upconvs = torch.nn.ModuleList(upconvs)

        # Normalization matching the backbone's pretrained configuration.
        # It is built here but not applied inside forward().
        self.normalize = transforms.Normalize(
            mean=self.backbone.pretrained_cfg["mean"],
            std=self.backbone.pretrained_cfg["std"],
        )

        self.out_conv = torch.nn.Conv2d(decoder_scale * bb_channels[-1], out_depth, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the backbone and the decoder on ``x`` and return the output map.

        The input is passed to the backbone as-is; ``self.normalize`` is not
        applied here.
        """
        # Backbone features reordered so the deepest map comes first.
        features = self.backbone(x)[::-1]

        # The first decoder stage has no previous output, so feed it a
        # placeholder with zero channels and the same remaining dimensions as
        # the deepest feature map. new_empty() inherits that map's dtype and
        # device, which keeps the placeholder compatible with half-precision
        # runs and avoids allocating on the CPU and copying to the device.
        deepest = features[0]
        void_shape = list(deepest.shape)
        void_shape[1] = 0
        p = deepest.new_empty(void_shape)

        for upconv, f in zip(self.upconvs, features):
            p = upconv(p, f)

        return self.out_conv(p)
24 changes: 21 additions & 3 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,26 @@
from tqdm import tqdm

import settings
from models import UNetBaseline
from models import UNetBaseline, UnetTimm

# Registry of the models selectable at prediction time, keyed by the value
# accepted by the -model command line option. Each entry holds an already
# constructed network ("model") and the checkpoint file to load into it
# ("checkpoint").
# NOTE: every model listed here is instantiated when this module is executed,
# regardless of which one is selected afterwards.
available_models = {
"baseline": {
"model": UNetBaseline(in_depth=settings.input_channels, out_depth=1, depth_scale=settings.baseline_model_scale),
"checkpoint": settings.baseline_checkpoint_path
},
# timm-backed U-Net; pretrained=False because the weights come from the
# checkpoint listed below rather than from timm.
"tunet": {
"model": UnetTimm(out_depth=1, backbone_name="efficientnet_b3", pretrained=False, decoder_scale=settings.timmunet_decoder_scale),
"checkpoint": settings.timmunet_checkpoint_path
}

}

# command line args
parser = argparse.ArgumentParser(description='Segment pokemon cards')

# -file and -folder both default to None.
parser.add_argument("-file", dest="file", help="Input image", type=str, default=None)
parser.add_argument("-folder", dest="folder", help="Folder where images (pngs or jpgs) are located", type=str, default=None)
# -model must be one of the keys of available_models; defaults to "tunet".
parser.add_argument("-model", dest="model", help=f"Model name {available_models.keys()}", type=str, default="tunet")

args = parser.parse_args()

Expand Down Expand Up @@ -49,11 +62,16 @@
if ".png" in f or ".jpg" in f:
files.append(os.path.join(args.folder, f))

# Abort early when the requested model is not in the registry.
if args.model not in available_models:
    # Report the rejected -model value (the message used to print args.file).
    print(f"No model named: {args.model}. Available models: {available_models.keys()}")
    # SystemExit does not depend on the exit() helper injected by the site module.
    raise SystemExit(1)

print(f"Using model: {args.model}")

# load model
model = UNetBaseline(in_depth=settings.input_channels, out_depth=1, depth_scale=settings.baseline_model_scale)
model = available_models[args.model]["model"]

checkpoint = torch.load(settings.checkpoint_path)
checkpoint = torch.load(available_models[args.model]["checkpoint"])
weights = checkpoint["state_dict"]
for key in list(weights):
weights[key.replace("model.", "")] = weights.pop(key)
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ matplotlib==3.5.2
opencv-python-headless==4.7.0.72
pillow==9.5.0
tqdm>=4.64.0
torchvision==0.15.2
torchvision==0.15.2
timm==0.9.2
8 changes: 5 additions & 3 deletions settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
output_root = "output"

# global settings
card_size = (512, 368)
card_size = (512, 352)

# data preprocessing settings
min_mask_occupaion_percent = 0.20
Expand All @@ -30,11 +30,13 @@
max_epochs = 150
input_channels = 3
use_noisy_labels = False
learn_rate = 1e-3
learn_rate = 1e-4
baseline_model_scale = 2
timmunet_decoder_scale = 1

# prediction settings
checkpoint_path = "checkpoints/baseline_150.ckpt"
baseline_checkpoint_path = "checkpoints/baseline_150.ckpt"
timmunet_checkpoint_path = "checkpoints/timmunet_eff3_0819.ckpt"

################################################################################

Expand Down
10 changes: 8 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import settings
from datasets import PkmCardSegmentationDataModule

from models import SimoidSegmentationModule, UNetBaseline
from models import SimoidSegmentationModule, UnetTimm


# data augmentation
Expand All @@ -25,7 +25,13 @@
use_noisy=settings.use_noisy_labels
)

torch_model = UNetBaseline(in_depth=settings.input_channels, out_depth=1, depth_scale=settings.baseline_model_scale)
# replace this line with the model you want to train
torch_model = UnetTimm(
out_depth=1,
backbone_name="efficientnet_b3",
pretrained=True,
decoder_scale=settings.timmunet_decoder_scale
)


model = SimoidSegmentationModule(torch_model, lr=settings.learn_rate)
Expand Down

0 comments on commit f00db09

Please sign in to comment.