From dbadb1729aa44557fba7ceccb49331f5739b862d Mon Sep 17 00:00:00 2001 From: Balthazar Neveu Date: Sat, 16 Mar 2024 16:38:41 +0100 Subject: [PATCH] #9 add losses and metrics --- src/pixr/learning/loss.py | 25 ++++++++++++++++++++ src/pixr/learning/metrics.py | 46 ++++++++++++++++++++++++++++++++++++ src/pixr/properties.py | 24 +++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 src/pixr/learning/loss.py create mode 100644 src/pixr/learning/metrics.py diff --git a/src/pixr/learning/loss.py b/src/pixr/learning/loss.py new file mode 100644 index 0000000..9354dc7 --- /dev/null +++ b/src/pixr/learning/loss.py @@ -0,0 +1,25 @@ +import torch +from typing import Optional +from pixr.properties import LOSS_MSE + + +def compute_loss( + predic: torch.Tensor, + target: torch.Tensor, + mode: Optional[str] = LOSS_MSE +) -> torch.Tensor: + """ + Compute loss based on the predicted and true values. + + Args: + predic (torch.Tensor): [N, C, H, W] predicted values + target (torch.Tensor): [N, C, H, W] target values. + mode (Optional[str], optional): mode of loss computation. + + Returns: + torch.Tensor: The computed loss. + """ + assert mode in [LOSS_MSE], f"Mode {mode} not supported" + if mode == LOSS_MSE: + loss = torch.nn.functional.mse_loss(predic, target) + return loss diff --git a/src/pixr/learning/metrics.py b/src/pixr/learning/metrics.py new file mode 100644 index 0000000..d27b226 --- /dev/null +++ b/src/pixr/learning/metrics.py @@ -0,0 +1,46 @@ +import torch +from pixr.properties import METRIC_PSNR, REDUCTION_AVERAGE, REDUCTION_SKIP + + +def compute_psnr( + predic: torch.Tensor, + target: torch.Tensor, + clamp_mse=1e-10, + reduction=REDUCTION_AVERAGE +) -> torch.Tensor: + """ + Compute the average PSNR metric for a batch of predicted and true values. + + Args: + predic (torch.Tensor): [N, C, H, W] predicted values. + target (torch.Tensor): [N, C, H, W] target values. + + Returns: + torch.Tensor: The average PSNR value for the batch. + """ + mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1)) + mse_per_image = torch.clamp(mse_per_image, min=clamp_mse) + psnr_per_image = 10 * torch.log10(1 / mse_per_image) + if reduction == REDUCTION_AVERAGE: + average_psnr = torch.mean(psnr_per_image) + elif reduction == REDUCTION_SKIP: + average_psnr = psnr_per_image + else: + raise ValueError(f"Unknown reduction {reduction}") + return average_psnr + + +def compute_metrics(predic: torch.Tensor, target: torch.Tensor, reduction=REDUCTION_AVERAGE) -> dict: + """ + Compute the metrics for a batch of predicted and true values. + + Args: + predic (torch.Tensor): [N, C, H, W] predicted values. + target (torch.Tensor): [N, C, H, W] target values. + + Returns: + dict: computed metrics. + """ + average_psnr = compute_psnr(predic, target, reduction=reduction) + metrics = {METRIC_PSNR: average_psnr.item() if reduction != REDUCTION_SKIP else average_psnr} + return metrics diff --git a/src/pixr/properties.py b/src/pixr/properties.py index 8e1d74a..3a64a74 100644 --- a/src/pixr/properties.py +++ b/src/pixr/properties.py @@ -8,3 +8,27 @@ X = 0 Y = 1 Z = 2 + +OUTPUT_FOLDER_NAME = "__output" +DATALOADER = "data_loader" +BATCH_SIZE = "batch_size" +TRAIN, VALIDATION, TEST = "train", "validation", "test" +ID = "id" +NAME = "name" +NB_EPOCHS = "nb_epochs" +ARCHITECTURE = "architecture" +MODEL = "model" +NAME = "name" +N_PARAMS = "n_params" +OPTIMIZER = "optimizer" +LR = "lr" +PARAMS = "parameters" +SCHEDULER_CONFIGURATION = "scheduler_configuration" +SCHEDULER = "scheduler" +REDUCELRONPLATEAU = "ReduceLROnPlateau" +LOSS = "loss" +LOSS_MSE = "mse" +METRICS = "metrics" +METRIC_PSNR = "psnr" +REDUCTION_AVERAGE = "average" +REDUCTION_SKIP = "skip"