#9 initiate training framework
balthazarneveu committed Mar 16, 2024
1 parent dbadb17 commit 65ccbbb
Showing 9 changed files with 196 additions and 19 deletions.
3 changes: 3 additions & 0 deletions scripts/config.py
@@ -0,0 +1,3 @@
from pathlib import Path
SAMPLE_SCENES = Path("__data")
OUT_DIR = Path("__output")
50 changes: 50 additions & 0 deletions scripts/experiments_definition.py
@@ -0,0 +1,50 @@
from pixr.properties import (NB_EPOCHS, TRAIN, VALIDATION, SCHEDULER, REDUCELRONPLATEAU,
                             MODEL, ARCHITECTURE, ID, NAME, SCHEDULER_CONFIGURATION, OPTIMIZER, PARAMS, LR,
                             LOSS, LOSS_MSE)
from typing import Tuple


def model_configurations(config, model_preset="UNet") -> dict:
    if model_preset == "UNet":
        config[MODEL] = {
            ARCHITECTURE: dict(
                width=64,
                enc_blk_nums=[1, 1, 1, 28],
                middle_blk_num=1,
                dec_blk_nums=[1, 1, 1, 1],
            ),
            NAME: model_preset
        }
    else:
        raise ValueError(f"Unknown model preset {model_preset}")
    return config


def presets_experiments(
    exp: int,
    n: int = 50,
    model_preset: str = "UNet"
) -> dict:
    config = {
        ID: exp,
        NAME: f"{exp:04d}",
        NB_EPOCHS: n
    }
    config[OPTIMIZER] = {
        NAME: "Adam",
        PARAMS: {
            LR: 1e-3
        }
    }
    model_configurations(config, model_preset=model_preset)
    config[SCHEDULER] = REDUCELRONPLATEAU
    config[SCHEDULER_CONFIGURATION] = {
        "factor": 0.8,
        "patience": 5
    }
    config[LOSS] = LOSS_MSE
    return config


def get_experiment(exp: int):
    if exp == 1:
        return presets_experiments(exp, n=50, model_preset="UNet")
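
For context, a quick sketch of how these presets might be consumed downstream (hypothetical usage, not part of this commit; it assumes the key constants exported by pixr.properties):

from experiments_definition import get_experiment
from pixr.properties import NAME, NB_EPOCHS, MODEL

config = get_experiment(1)
print(config[NAME])         # "0001"
print(config[NB_EPOCHS])    # 50
print(config[MODEL][NAME])  # "UNet"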
40 changes: 40 additions & 0 deletions scripts/optimize_point_based_neural_renderer.py
@@ -0,0 +1,40 @@
from pixr.synthesis.world_simulation import STAIRCASE
from pixr.learning.utils import prepare_dataset
import torch

from config import OUT_DIR
from pixr.properties import DEVICE
import argparse
# from experiments_definition import
DEFAULT_CONFIG = {}


def main(out_root=OUT_DIR, name=STAIRCASE, device=DEVICE, show=True, save=False, config: dict = DEFAULT_CONFIG):
    train_material, valid_material, (w, h), point_cloud_material = prepare_dataset(out_root, name)
    # Move training data to GPU
    wc_points, wc_normals = point_cloud_material
    wc_points = wc_points.to(device)
    wc_normals = wc_normals.to(device)
    rendered_view_train, camera_intrinsics_train, camera_extrinsics_train = train_material
    rendered_view_train = rendered_view_train.to(device)
    camera_intrinsics_train = camera_intrinsics_train.to(device)
    camera_extrinsics_train = camera_extrinsics_train.to(device)
    # Validation images can remain on CPU
    rendered_view_valid, camera_intrinsics_valid, camera_extrinsics_valid = valid_material
    rendered_view_valid = rendered_view_valid.cpu()
    n_steps = 200

    # # Initialize trainable parameters and optimizer
    # optimizer = torch.optim.Adam([color_pred], lr=0.3)
    # for step in range(n_steps):
    #     optimizer.zero_grad()
    #     loss = 0.


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Optimize a point-based neural renderer")
    parser.add_argument("-s", "--scene", type=str, help="Name of the scene to render", default=STAIRCASE)
    parser.add_argument("-e", "--export", action="store_true", help="Export validation image")
    args = parser.parse_args()
    main(name=args.scene, show=False, save=args.export)
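
The training loop itself is still commented out. A minimal sketch of what those lines suggest, to be placed inside main; here color_pred (a trainable per-point color tensor) and the dummy loss are assumptions, not part of the commit:

    # Hypothetical sketch: optimize per-point colors with Adam
    color_pred = torch.zeros(wc_points.shape[0], 3, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([color_pred], lr=0.3)
    for step in range(n_steps):
        optimizer.zero_grad()
        loss = color_pred.pow(2).mean()  # dummy loss standing in for an image MSE against rendered_view_train
        loss.backward()
        optimizer.step()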
14 changes: 14 additions & 0 deletions src/pixr/learning/architecture.py
@@ -0,0 +1,14 @@
from pixr.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE
from pixr.learning.unet import UNet
import torch


def load_architecture(config: dict) -> torch.nn.Module:
    conf_model = config[MODEL][ARCHITECTURE]
    if config[MODEL][NAME] == UNet.__name__:
        model = UNet(**conf_model)
    else:
        raise ValueError(f"Unknown model {config[MODEL][NAME]}")
    config[MODEL][N_PARAMS] = model.count_parameters()
    config[MODEL]["receptive_field"] = model.receptive_field()
    return model
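
An illustrative round trip between the presets and this loader (assumes scripts/experiments_definition.py and the pixr package are both importable; not part of the commit):

from experiments_definition import get_experiment
from pixr.learning.architecture import load_architecture
from pixr.properties import MODEL, N_PARAMS

config = get_experiment(1)
model = load_architecture(config)
# load_architecture fills in the parameter count and receptive field as a side effect
print(config[MODEL][N_PARAMS], config[MODEL]["receptive_field"])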
56 changes: 56 additions & 0 deletions src/pixr/learning/base.py
@@ -0,0 +1,56 @@
import torch
from pixr.properties import LEAKY_RELU, RELU, SIMPLE_GATE
from typing import Optional, Tuple


class SimpleGate(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


def get_non_linearity(activation: str):
    if activation == LEAKY_RELU:
        non_linearity = torch.nn.LeakyReLU()
    elif activation == RELU:
        non_linearity = torch.nn.ReLU()
    elif activation is None:
        non_linearity = torch.nn.Identity()
    elif activation == SIMPLE_GATE:
        non_linearity = SimpleGate()
    else:
        raise ValueError(f"Unknown activation {activation}")
    return non_linearity


class BaseModel(torch.nn.Module):
    """Base class for all restoration models with additional useful methods"""

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(
        self,
        channels: int = 3,
        size: int = 256,
        device: Optional[str] = None
    ) -> Tuple[int, int]:
        """Compute the receptive field of the model.

        Returns:
            Tuple[int, int]: horizontal and vertical receptive field sizes
        """
        # Create the probe directly on the target device so it remains a leaf
        # tensor and its .grad attribute gets populated by backward().
        input_tensor = torch.ones(1, channels, size, size, device=device, requires_grad=True)
        out = self.forward(input_tensor)
        grad = torch.zeros_like(out)
        grad[..., out.shape[-2]//2, out.shape[-1]//2] = torch.nan  # set NaN gradient at the middle of the output
        out.backward(gradient=grad)
        self.zero_grad()
        # NaN propagates back to every input pixel that influences the central output pixel
        receptive_field_mask = input_tensor.grad.isnan()[0, 0]
        receptive_field_indexes = torch.where(receptive_field_mask)
        receptive_x = 1 + receptive_field_indexes[-1].max() - receptive_field_indexes[-1].min()  # Horizontal x
        receptive_y = 1 + receptive_field_indexes[-2].max() - receptive_field_indexes[-2].min()  # Vertical y
        return receptive_x.item(), receptive_y.item()
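
The receptive field is measured by backpropagating a NaN gradient from the central output pixel: the NaN contaminates exactly those input pixels that influence it. A small sanity check of that trick (illustrative, not in the commit): two stacked 3x3 convolutions should report a 5x5 receptive field.

import torch
from pixr.learning.base import BaseModel


class TwoConvs(BaseModel):
    def __init__(self):
        super().__init__()
        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1),
            torch.nn.Conv2d(8, 3, 3, padding=1),
        )

    def forward(self, x):
        return self.body(x)


print(TwoConvs().receptive_field())  # expected: (5, 5)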
15 changes: 15 additions & 0 deletions src/pixr/learning/experiments.py
@@ -0,0 +1,15 @@
from pixr.properties import DEVICE, OPTIMIZER, PARAMS
from pixr.learning.architecture import load_architecture
from typing import Optional, Tuple
import torch


def get_training_content(
    config: dict,
    training_mode: bool = False,
) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer]]:
    model = load_architecture(config)
    optimizer = None
    if training_mode:
        optimizer = torch.optim.Adam(model.parameters(), **config[OPTIMIZER][PARAMS])
    return model, optimizer
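
Hypothetical usage (not in the commit), pairing this with the presets to get a model together with its Adam optimizer:

from experiments_definition import get_experiment
from pixr.learning.experiments import get_training_content

config = get_experiment(1)
model, optimizer = get_training_content(config, training_mode=True)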
15 changes: 15 additions & 0 deletions src/pixr/learning/unet.py
@@ -0,0 +1,15 @@
from pixr.learning.base import BaseModel
import torch.nn as nn


class UNet(BaseModel):
    def __init__(self, in_channels=3, out_channels=3, k_size=3, encoder=[1, 1, 1, 1], decoder=[1, 1, 1, 1], thickness=4, **kwargs):
        super().__init__()
        # Placeholder architecture: a single convolution for now.
        # The encoder/decoder block counts and thickness are not used yet.
        self.conv1 = nn.Conv2d(in_channels, out_channels, k_size, padding=k_size//2)
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x):
        x = self.conv1(x)
        return x
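
A quick shape check of the placeholder (illustrative only): the single convolution preserves spatial dimensions thanks to the k_size//2 padding.

import torch
from pixr.learning.unet import UNet

net = UNet()
print(net(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])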
3 changes: 3 additions & 0 deletions src/pixr/properties.py
@@ -32,3 +32,6 @@
METRIC_PSNR = "psnr"
REDUCTION_AVERAGE = "average"
REDUCTION_SKIP = "skip"
LEAKY_RELU = "LeakyReLU"
RELU = "ReLU"
SIMPLE_GATE = "SimpleGate"
19 changes: 0 additions & 19 deletions studies/optimize_point_based_neural_renderer.py

This file was deleted.
