-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
dbadb17
commit 65ccbbb
Showing
9 changed files
with
196 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from pathlib import Path | ||
# Relative path "__data"; presumably the folder holding sample scenes - confirm against its users.
SAMPLE_SCENES = Path("__data")
# Relative path "__output"; presumably the root folder for generated outputs - confirm against its users.
OUT_DIR = Path("__output")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from pixr.properties import (NB_EPOCHS, TRAIN, VALIDATION, SCHEDULER, REDUCELRONPLATEAU, | ||
MODEL, ARCHITECTURE, ID, NAME, SCHEDULER_CONFIGURATION, OPTIMIZER, PARAMS, LR, | ||
LOSS, LOSS_MSE) | ||
from typing import Tuple | ||
|
||
|
||
def model_configurations(config, model_preset="UNet") -> dict:
    """Fill ``config[MODEL]`` with the settings of a named model preset.

    The dictionary is updated in place and is also returned, so that the
    declared ``dict`` return type actually holds.

    Args:
        config: Experiment configuration dictionary to update.
        model_preset: Name of the preset; only ``"UNet"`` is known.

    Returns:
        The same ``config`` object, with ``config[MODEL]`` set.

    Raises:
        ValueError: If ``model_preset`` is not a known preset.
    """
    if model_preset != "UNet":
        raise ValueError(f"Unknown model preset {model_preset}")
    config[MODEL] = {
        ARCHITECTURE: dict(
            width=64,
            enc_blk_nums=[1, 1, 1, 28],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1, 1],
        ),
        NAME: model_preset,
    }
    return config
|
||
|
||
def presets_experiments(
    exp: int,
    n: int = 50,
    model_preset: str = "UNet"
) -> dict:
    """Assemble the configuration dictionary of one experiment.

    Args:
        exp: Experiment identifier, stored under ``ID`` and, zero-padded to
            four digits, under ``NAME``.
        n: Value stored under ``NB_EPOCHS``.
        model_preset: Preset name forwarded to ``model_configurations``.

    Returns:
        A new dictionary holding the identifier, optimizer (Adam, learning
        rate 1e-3), model, scheduler and loss entries.
    """
    config = {
        ID: exp,
        NAME: f"{exp:04d}",
        NB_EPOCHS: n,
        OPTIMIZER: {NAME: "Adam", PARAMS: {LR: 1e-3}},
    }
    # The model section is written in place by the helper.
    model_configurations(config, model_preset=model_preset)
    config.update({
        SCHEDULER: REDUCELRONPLATEAU,
        SCHEDULER_CONFIGURATION: {"factor": 0.8, "patience": 5},
        LOSS: LOSS_MSE,
    })
    return config
|
||
|
||
def get_experiment(exp: int) -> dict:
    """Return the configuration registered for experiment ``exp``.

    Args:
        exp: Experiment identifier; only ``1`` is registered.

    Returns:
        The configuration built by ``presets_experiments``.

    Raises:
        ValueError: If ``exp`` is not a registered experiment. Previously
            the function fell through and returned ``None`` silently.
    """
    if exp == 1:
        return presets_experiments(exp, n=50, model_preset="UNet")
    raise ValueError(f"Unknown experiment {exp}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from pixr.synthesis.world_simulation import STAIRCASE | ||
from pixr.learning.utils import prepare_dataset | ||
import torch | ||
|
||
from config import OUT_DIR | ||
from pixr.properties import DEVICE | ||
import argparse | ||
# from experiments_definition import | ||
DEFAULT_CONFIG = {}


def main(out_root=OUT_DIR, name=STAIRCASE, device=DEVICE, show=True, save=False, config: dict = DEFAULT_CONFIG):
    """Load a scene dataset and place its tensors on the right devices.

    The optimisation loop is not implemented yet: ``show``, ``save`` and
    ``config`` are accepted but unused, and nothing is returned.

    Args:
        out_root: First argument forwarded to ``prepare_dataset``.
        name: Scene name forwarded to ``prepare_dataset``.
        device: Device receiving the point cloud and the training tensors.
        show: Currently unused.
        save: Currently unused.
        config: Currently unused.
    """
    # NOTE(review): the default for `config` is the shared module-level dict.
    dataset = prepare_dataset(out_root, name)
    train_material, valid_material, (w, h), point_cloud_material = dataset

    # Point cloud and training material are moved to `device`.
    points, normals = point_cloud_material
    wc_points = points.to(device)
    wc_normals = normals.to(device)
    views, intrinsics, extrinsics = train_material
    rendered_view_train = views.to(device)
    camera_intrinsics_train = intrinsics.to(device)
    camera_extrinsics_train = extrinsics.to(device)

    # Validation images can remain on CPU.
    rendered_view_valid, camera_intrinsics_valid, camera_extrinsics_valid = valid_material
    rendered_view_valid = rendered_view_valid.cpu()
    n_steps = 200

    # # Initialize trainable parameters
    # optimizer
    # optimizer = torch.optim.Adam([color_pred], lr=0.3)
    # for step in range(n_steps):
    #     optimizer.zero_grad()
    #     loss = 0.
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: choose the scene and optionally set the export flag.
    cli = argparse.ArgumentParser(description="Render a scene using BlenderProc")
    cli.add_argument("-s", "--scene", default=STAIRCASE, type=str, help="Name of the scene to render")
    cli.add_argument("-e", "--export", help="Export validation image", action="store_true")
    options = cli.parse_args()
    main(name=options.scene, show=False, save=options.export)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from pixr.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE | ||
from pixr.learning.unet import UNet | ||
import torch | ||
|
||
|
||
def load_architecture(config: dict) -> torch.nn.Module:
    """Instantiate the model described by ``config[MODEL]``.

    As a side effect, the parameter count and the receptive field of the
    created model are written back into ``config[MODEL]``.

    Args:
        config: Configuration whose ``MODEL`` entry holds ``ARCHITECTURE``
            (constructor keyword arguments) and ``NAME``.

    Returns:
        The freshly built model.

    Raises:
        ValueError: If the model name is not ``UNet``'s class name.
    """
    model_section = config[MODEL]
    architecture_kwargs = model_section[ARCHITECTURE]
    model_name = model_section[NAME]
    if model_name != UNet.__name__:
        raise ValueError(f"Unknown model {model_name}")
    model = UNet(**architecture_kwargs)
    # Record model statistics next to the architecture description.
    model_section[N_PARAMS] = model.count_parameters()
    model_section["receptive_field"] = model.receptive_field()
    return model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
from pixr.properties import LEAKY_RELU, RELU, SIMPLE_GATE | ||
from typing import Optional, Tuple | ||
|
||
|
||
class SimpleGate(torch.nn.Module):
    """Gate that multiplies the two halves of the input split along dim 1."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Split ``x`` in two chunks along dim 1 and return their product."""
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
|
||
|
||
def get_non_linearity(activation: str):
    """Build the activation module matching ``activation``.

    Args:
        activation: One of ``LEAKY_RELU``, ``RELU``, ``SIMPLE_GATE``, or
            ``None`` for an identity module.

    Returns:
        A new module instance implementing the requested non-linearity.

    Raises:
        ValueError: If ``activation`` is none of the supported values.
    """
    # Checks are kept in the same order as before.
    if activation == LEAKY_RELU:
        return torch.nn.LeakyReLU()
    if activation == RELU:
        return torch.nn.ReLU()
    if activation is None:
        return torch.nn.Identity()
    if activation == SIMPLE_GATE:
        return SimpleGate()
    raise ValueError(f"Unknown activation {activation}")
|
||
|
||
class BaseModel(torch.nn.Module):
    """Base class for all restoration models with additional useful methods."""

    def count_parameters(self) -> int:
        """Return the number of elements over all trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(
        self,
        channels: Optional[int] = 3,
        size: Optional[int] = 256,
        device: Optional[str] = None
    ) -> Tuple[int, int]:
        """Measure the receptive field of the model empirically.

        A NaN gradient is injected at the centre of the output and
        back-propagated; the input pixels whose gradient becomes NaN are the
        ones that influence that output location.

        Args:
            channels: Number of channels of the probe input.
            size: Height and width of the square probe input.
            device: Device on which the probe is allocated; ``None`` uses
                the default device.

        Returns:
            ``(receptive_x, receptive_y)``: horizontal and vertical extent,
            in pixels, of the NaN region in the input gradient.
        """
        # Allocate the probe directly on the target device. Calling
        # `.to(device)` on a tensor created with requires_grad=True returns a
        # non-leaf copy when the device differs, and its `.grad` stays None.
        input_tensor = torch.ones(1, channels, size, size, device=device, requires_grad=True)
        out = self.forward(input_tensor)
        grad = torch.zeros_like(out)
        grad[..., out.shape[-2]//2, out.shape[-1]//2] = torch.nan  # set NaN gradient at the middle of the output
        out.backward(gradient=grad)
        # Discard the NaN gradients accumulated on the model parameters.
        self.zero_grad()
        receptive_field_mask = input_tensor.grad.isnan()[0, 0]
        receptive_field_indexes = torch.where(receptive_field_mask)
        # Extent of the NaN region in the input
        receptive_x = 1+receptive_field_indexes[-1].max() - receptive_field_indexes[-1].min()  # Horizontal x
        receptive_y = 1+receptive_field_indexes[-2].max() - receptive_field_indexes[-2].min()  # Vertical y
        return receptive_x.item(), receptive_y.item()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import Optional, Tuple

import torch

from pixr.learning.architecture import load_architecture
from pixr.properties import DEVICE, OPTIMIZER, PARAMS
|
||
|
||
def get_training_content(
    config: dict,
    training_mode: bool = False,
) -> Tuple[torch.nn.Module, Optional[torch.optim.Optimizer]]:
    """Build the model and, when training, its optimizer.

    Args:
        config: Experiment configuration; passed to ``load_architecture``
            and, in training mode, read at ``config[OPTIMIZER][PARAMS]`` for
            the optimizer keyword arguments.
        training_mode: When true, an Adam optimizer over the model
            parameters is created.

    Returns:
        ``(model, optimizer)``; ``optimizer`` is ``None`` unless
        ``training_mode`` is true.
    """
    model = load_architecture(config)
    optimizer = None
    if training_mode:
        optimizer = torch.optim.Adam(model.parameters(), **config[OPTIMIZER][PARAMS])
    return model, optimizer
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from pixr.learning.base import BaseModel | ||
import torch.nn as nn | ||
|
||
|
||
class UNet(BaseModel):
    """Model made of a single 2D convolution.

    ``encoder``, ``decoder``, ``thickness`` and any extra keyword arguments
    are accepted but not used by the current implementation.
    """

    def __init__(self, in_channels=3, out_channels=3, k_size=3, encoder=[1, 1, 1, 1], decoder=[1, 1, 1, 1], thickness=4, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Padding of k_size // 2 on each side of the input.
        self.conv1 = nn.Conv2d(in_channels, out_channels, k_size, padding=k_size//2)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        return self.conv1(x)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.