Skip to content

Commit

Permalink
#9 add losses and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
balthazarneveu committed Mar 16, 2024
1 parent 7c218a8 commit dbadb17
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/pixr/learning/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
from typing import Optional
from pixr.properties import LOSS_MSE


def compute_loss(
predic: torch.Tensor,
target: torch.Tensor,
mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
"""
Compute loss based on the predicted and true values.
Args:
predic (torch.Tensor): [N, C, H, W] predicted values
target (torch.Tensor): [N, C, H, W] target values.
mode (Optional[str], optional): mode of loss computation.
Returns:
torch.Tensor: The computed loss.
"""
assert mode in [LOSS_MSE], f"Mode {mode} not supported"
if mode == LOSS_MSE:
loss = torch.nn.functional.mse_loss(predic, target)
return loss
46 changes: 46 additions & 0 deletions src/pixr/learning/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch
from pixr.properties import METRIC_PSNR, REDUCTION_AVERAGE, REDUCTION_SKIP


def compute_psnr(
predic: torch.Tensor,
target: torch.Tensor,
clamp_mse=1e-10,
reduction=REDUCTION_AVERAGE
) -> torch.Tensor:
"""
Compute the average PSNR metric for a batch of predicted and true values.
Args:
predic (torch.Tensor): [N, C, H, W] predicted values.
target (torch.Tensor): [N, C, H, W] target values.
Returns:
torch.Tensor: The average PSNR value for the batch.
"""
mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
psnr_per_image = 10 * torch.log10(1 / mse_per_image)
if reduction == REDUCTION_AVERAGE:
average_psnr = torch.mean(psnr_per_image)
elif reduction == REDUCTION_SKIP:
average_psnr = psnr_per_image
else:
raise ValueError(f"Unknown reduction {reduction}")
return average_psnr


def compute_metrics(predic: torch.Tensor, target: torch.Tensor, reduction=REDUCTION_AVERAGE) -> dict:
"""
Compute the metrics for a batch of predicted and true values.
Args:
predic (torch.Tensor): [N, C, H, W] predicted values.
target (torch.Tensor): [N, C, H, W] target values.
Returns:
dict: computed metrics.
"""
average_psnr = compute_psnr(predic, target, reduction=reduction)
metrics = {METRIC_PSNR: average_psnr.item() if reduction != REDUCTION_SKIP else average_psnr}
return metrics
24 changes: 24 additions & 0 deletions src/pixr/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,27 @@
X = 0
Y = 1
Z = 2

OUTPUT_FOLDER_NAME = "__output"
DATALOADER = "data_loader"
BATCH_SIZE = "batch_size"
TRAIN, VALIDATION, TEST = "train", "validation", "test"
ID = "id"
NAME = "name"
NB_EPOCHS = "nb_epochs"
ARCHITECTURE = "architecture"
MODEL = "model"
NAME = "name"
N_PARAMS = "n_params"
OPTIMIZER = "optimizer"
LR = "lr"
PARAMS = "parameters"
SCHEDULER_CONFIGURATION = "scheduler_configuration"
SCHEDULER = "scheduler"
REDUCELRONPLATEAU = "ReduceLROnPlateau"
LOSS = "loss"
LOSS_MSE = "mse"
METRICS = "metrics"
METRIC_PSNR = "psnr"
REDUCTION_AVERAGE = "average"
REDUCTION_SKIP = "skip"

0 comments on commit dbadb17

Please sign in to comment.