Refactor regularization step closes #117 #200

Merged · 18 commits · Dec 2, 2023
5 changes: 5 additions & 0 deletions README.md
@@ -146,6 +146,11 @@ python -m art.cli run-dashboard
2. A tutorial showing how to use ART for transfer learning in an NLP task.
```sh
python -m art.cli bert-transfer-learning-tutorial

```
3. A tutorial showing how to use ART for regularization.
```sh
python -m art.cli regularization-tutorial
```

## Contributing
9 changes: 9 additions & 0 deletions art/cli/main.py
@@ -97,5 +97,14 @@ def bert_transfer_learning_tutorial():
)


@app.command()
def regularization_tutorial():
"""Creates a regularize tutorial."""
create_project(
project_name="regularize_tutorial",
branch="regularize_tutorial",
)


if __name__ == "__main__":
app()
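For a quick local sanity check of the new command, something along these lines should work. The dashed command name is an assumption based on how the other commands in the README are invoked (run-dashboard, bert-transfer-learning-tutorial), i.e. Typer deriving command names from function names with underscores turned into dashes:

```sh
# Hypothetical invocation; the command name assumes the underscore-to-dash
# convention used by the existing tutorial commands.
python -m art.cli --help                    # the new tutorial command should be listed
python -m art.cli regularization-tutorial   # scaffolds the regularize_tutorial project
```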
40 changes: 5 additions & 35 deletions art/core.py
@@ -8,7 +8,7 @@
from torch.utils.data import DataLoader

from art.metrics import MetricCalculator
from art.utils.enums import LOSS, PREDICTION, TARGET
from art.utils.enums import LOSS, PREDICTION, TARGET, TrainingStage


class ArtModule(L.LightningModule, ABC):
@@ -18,6 +18,7 @@ def __init__(
super().__init__()
self.regularized = True
self.set_pipelines()
self.stage: TrainingStage = TrainingStage.TRAIN

"""
A module for managing the training process and application of various model configurations.
@@ -60,40 +61,6 @@ def set_pipelines(self):
]
self.ml_train_pipeline = [self.ml_parse_data, self.baseline_train]

def turn_on_model_regularizations(self):
"""
Turn on model regularizations.
"""
if not self.regularized:
for param in self.parameters():
name, obj = param
if isinstance(obj, torch.nn.Dropout):
obj.p = self.unregularized_params[name]

self.configure_optimizers = self.original_configure_optimizers

self.regularized = True

def turn_off_model_reguralizations(self):
"""
Turn off model regularizations.
"""
if self.regularized:
self.unregularized_params = {}
for param in self.parameters():
name, obj = param
if isinstance(obj, torch.nn.Dropout):
self.unregularized_params[name] = obj.p
obj.p = 0

# Simple Adam, no fancy optimizers at this stage
self.original_configure_optimizers = self.configure_optimizers
self.configure_optimizers = lambda self: torch.optim.Adam(
self.parameters(), lr=3e-4
)

self.regularized = False

def parse_data(self, data: Dict):
"""
Parse data.
@@ -153,6 +120,7 @@ def validation_step(
batch (Union[Dict[str, Any], DataLoader, torch.Tensor]): Batch to validate.
batch_idx (int): Batch index.
"""
self.stage = TrainingStage.VALIDATION
data = {"batch": batch, "batch_idx": batch_idx}
for func in self.validation_step_pipeline:
data = func(data)
@@ -170,6 +138,7 @@ def training_step(
Returns:
Dict: Data with loss.
"""
self.stage = TrainingStage.TRAIN
data = {"batch": batch, "batch_idx": batch_idx}
for func in self.train_step_pipeline:
data = func(data)
@@ -186,6 +155,7 @@ def test_step(
batch (Union[Dict[str, Any], DataLoader, torch.Tensor]): Batch to test.
batch_idx (int): Batch index.
"""
self.stage = TrainingStage.TEST
data = {"batch": batch, "batch_idx": batch_idx}
for func in self.validation_step_pipeline:
data = func(data)
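The hunks above reference TrainingStage from art.utils.enums without showing its definition. A minimal sketch of what such an enum could look like, for orientation only (the actual definition in art/utils/enums.py may differ):

```python
from enum import Enum


class TrainingStage(Enum):
    # Assumed members, mirroring the three stages assigned in the steps above.
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
```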
29 changes: 29 additions & 0 deletions art/dashboard/backend.py
@@ -30,6 +30,11 @@ def prepare_steps_info(logs_path: Path) -> Dict[str, Dict]:
step_name = step_info["name"]
step_model = step_info["model"]
for run in step_info["runs"]:
if "regularize" in run["parameters"]:
run["parameters"]["regularize"] = stringify_regularize(
run["parameters"]["regularize"]
)

new_sample = {
"model": step_model,
**run["scores"],
@@ -53,6 +58,30 @@ def prepare_steps_info(logs_path: Path) -> Dict[str, Dict]:
return steps_info


def stringify_regularize(regularize: Dict) -> str:
"""Since regularize field contain list we must handle them with special care .

Args:
regularize (Dict): regularize field from results.json

Returns:
str: stringified version of regularize field
"""
parameters = []
for key, value in regularize.items():
if key in ["model_modifiers", "datamodule_modifiers"]:
continue
parameters.append(f"{key}={value}")
representation = ""
if parameters:
representation += f"model-kwargs={' '.join(parameters)} |"
if regularize["model_modifiers"]:
representation += f"model-modifiers={regularize['model_modifiers']} |"
if regularize["datamodule_modifiers"]:
representation += f"datamodule-modifiers={regularize['datamodule_modifiers']}"
return representation


def prepare_steps():
return [
"Data analysis",
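To illustrate the string this helper produces, here is a hypothetical regularize entry run through stringify_regularize as defined above. The dropout and weight_decay keys are invented for the example; only model_modifiers and datamodule_modifiers are accessed unconditionally by the code:

```python
from art.dashboard.backend import stringify_regularize

regularize = {
    "dropout": 0.1,        # hypothetical model kwarg
    "weight_decay": 1e-4,  # hypothetical model kwarg
    "model_modifiers": ["add_dropout"],
    "datamodule_modifiers": [],
}

print(stringify_regularize(regularize))
# model-kwargs=dropout=0.1 weight_decay=0.0001 |model-modifiers=['add_dropout'] |
```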
19 changes: 15 additions & 4 deletions art/decorators.py
@@ -2,8 +2,6 @@
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

from torchvision.utils import save_image

from art.loggers import art_logger, supress_stdout
from art.utils.enums import INPUT, PREDICTION, TARGET

@@ -93,6 +91,14 @@ def __init__(self, how_many_batches=10, image_key_name=INPUT):
how_many_batches (int, optional): How many batches to save. Defaults to 10.
image_key_name (str, optional): Key in the data dict under which the image tensor is stored. Defaults to "input".
"""
try:
from torchvision.utils import save_image
except ImportError:
raise ImportError(
"You need to install torchvision to use BatchSaver decorator"
)

self.save_image = save_image
self.time = 0
self.how_many_batches = how_many_batches
self.image_key_name = image_key_name
@@ -109,7 +115,7 @@ def __call__(self, data: Dict):
img_ = data[self.image_key_name]
min_, max_ = img_.min(), img_.max()
img_ = ((img_ - min_) / (max_ - min_)) * 255
save_image(img_, self.img_path / f"{self.time}.png")
self.save_image(img_, self.img_path / f"{self.time}.png")
self.time += 1


@@ -122,7 +128,12 @@ def __init__(self, suppress_stdout=True, custom_logger=None):
suppress_stdout (bool, optional): Whether to suppress stdout. Defaults to True.
custom_logger (optional): Custom logger to use instead of the default art_logger. Defaults to None.
"""
import lovely_tensors as lt
try:
import lovely_tensors as lt
except ImportError:
raise ImportError(
"You need to install lovely_tensors to use LogInputStats decorator"
)

lt.monkey_patch()
self.logger = art_logger if custom_logger is None else custom_logger
19 changes: 14 additions & 5 deletions art/loggers.py
@@ -4,14 +4,12 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
from lightning.pytorch.loggers import NeptuneLogger, WandbLogger
from loguru import logger

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from loguru import Logger

@@ -21,7 +19,14 @@ class LoggerFlags(Enum):


logger.remove()
logger.add(sys.stdout, format="{message}", level="DEBUG", filter= lambda record: not record["extra"].get(LoggerFlags.SUPRESS_STDOUT.value, False))
logger.add(
sys.stdout,
format="{message}",
level="DEBUG",
filter=lambda record: not record["extra"].get(
LoggerFlags.SUPRESS_STDOUT.value, False
),
)


def get_run_id() -> str:
@@ -42,13 +47,17 @@ def remove_logger(logger_id: int):
art_logger.remove(logger_id)


def supress_stdout(current_logger: 'Logger') -> 'Logger':
def supress_stdout(current_logger: "Logger") -> "Logger":
return current_logger.bind(**{LoggerFlags.SUPRESS_STDOUT.value: True})


art_logger = logger


def log_yellow_warning(message: str):
art_logger.opt(ansi=True).warning(f"<yellow>{message}</yellow>")


class NeptuneLoggerAdapter(NeptuneLogger):
"""
This is a wrapper for LightningLogger for simplifying basic functionalities between different loggers.
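A short sketch of how the suppression flag interacts with the stdout sink configured above: the logger returned by supress_stdout carries the extra flag, so its records are dropped by the filter, while the plain art_logger still reaches stdout.

```python
from art.loggers import art_logger, supress_stdout

quiet_logger = supress_stdout(art_logger)

art_logger.info("visible on stdout")     # passes the stdout filter
quiet_logger.info("hidden from stdout")  # extra flag is set, so the filter drops it
```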
13 changes: 12 additions & 1 deletion art/metrics.py
@@ -31,6 +31,17 @@ def __init__(
self.stages = stages


def build_metric_name(metric: Any, stage: str) -> str:
"""
Builds a name for the metric based on its type and given training stage.

Args:
metric (Any): The metric being calculated.
stage (str): The current stage of training.
"""
return f"{metric.__class__.__name__}-{stage}"


class MetricCalculator:
"""
Facilitates the management and application of metrics during different stages of training.
@@ -51,7 +62,7 @@ def build_name(self, metric: Any) -> str:
"""
stage = self.experiment.state.get_current_stage()

return f"{metric.__class__.__name__}-{stage}"
return build_metric_name(metric, stage)

def to(self, device: str):
"""
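For reference, a minimal illustration of the naming scheme build_metric_name implements, using a stand-in class instead of a real metric object to stay dependency-free:

```python
from art.metrics import build_metric_name


class Accuracy:
    """Stand-in for any metric object; only the class name matters here."""


print(build_metric_name(Accuracy(), "train"))       # Accuracy-train
print(build_metric_name(Accuracy(), "validation"))  # Accuracy-validation
```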
17 changes: 12 additions & 5 deletions art/project.py
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Union

import lightning as L

Expand Down Expand Up @@ -165,6 +165,8 @@ def check_if_must_be_run(self, step: "Step", checks: List[Check]) -> bool:
"""
if not step.was_run():
return True
if step.check_if_already_tried():
return False
else:
step_current_hash = step.get_hash()
step_saved_hash = step.get_latest_run()["hash"]
@@ -249,8 +251,9 @@ def run_all(
self.check_checks(step, checks)
except CheckFailedException as e:
art_logger.warning(e)
step.save_to_disk()
break
if not step.continue_on_failure:
step.save_to_disk()
break

self.fill_step_states(step)
step.save_to_disk()
@@ -268,13 +271,17 @@ def print_summary(self):
"""
art_logger.info("Summary: ")
for step in self.steps:
art_logger.info(step["step"])
if not step["step"].is_successful():
step = step["step"]
art_logger.info(step)

if not step.is_successful() and not step.continue_on_failure:
break

if len(self.changed_steps) > 0:
art_logger.info(
f"Code of the following steps was changed: {', '.join(self.changed_steps)}\n Rerun could be needed."
)
art_logger.info("Explore all runs with `python -m art.cli run-dashboard`")

def get_steps(self):
"""