Merge pull request #4 from caidongqi/fl-update
 add callback hook to extract gradients during training.
Iacob-Alexandru-Andrei authored Jan 15, 2025
2 parents 593fae8 + 3208aa0 commit 5af6f93
Showing 7 changed files with 387 additions and 0 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -25,6 +25,7 @@
from composer.callbacks.speed_monitor import SpeedMonitor
from composer.callbacks.system_metrics_monitor import SystemMetricsMonitor
from composer.callbacks.threshold_stopper import ThresholdStopper
from composer.callbacks.grad_monitor import GradMonitor

__all__ = [
'ActivationMonitor',
@@ -46,4 +47,5 @@
'MemorySnapshot',
'OOMObserver',
'LoadCheckpoint',
'GradMonitor',
]
82 changes: 82 additions & 0 deletions composer/callbacks/grad_monitor.py
@@ -0,0 +1,82 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Monitor gradients during DDP training."""

import warnings
from typing import Union
import torch

from composer.core import Callback, State, Time, TimeUnit
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['GradMonitor']


class GradMonitor(Callback):
"""extracts gradients from the self.state.model during training.
The extracted gradients are stored in the self.state.grads attribute, in the form of a list of tensors.
Example:
.. doctest::
>>> from composer import Trainer
>>> from composer.callbacks import GradMonitor
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration="1ep",
... callbacks=[GradMonitor()],
... )
"""

def __init__(self) -> None:
self.num_microbatches = 0
self.executed_steps = 0

def _extract_grads(self, state: State, device: torch.device = torch.device('cpu')) -> None:
"""Extracts gradients of each batch from the model
A running average of the gradients is stored in the state.
Args:
state (State): The state object.
device (torch.device, optional): The device to store the gradients. Defaults to CPU.
"""

group = list(state.model.parameters())
grad_list = []
for p in group:
if p.grad is not None:
grad_list.append(p.grad.detach().clone().to(device))

# average the gradients
prev_grads = state.grads
if prev_grads:
aver_grad_list = [(prev_grads[i] * self.executed_steps + grad_list[i]) / (self.executed_steps + 1) for i in range(len(prev_grads))]
else: # the first batch, no need to average
aver_grad_list = grad_list

self.executed_steps = self.executed_steps + 1

if self.executed_steps == state.local_steps: # averaged gradients will be sent to the cloud, so we can reset the counter
self.executed_steps = 0

state.grads = aver_grad_list


def after_backward(self, state: State, logger: Logger) -> None:
"""Runs on ``Event.AFTER_BACKWARD`` in the function of _train_microbatch.
"""
assert state.total_num_microbatches is not None, "The total number of microbatches must be set"
self.num_microbatches = self.num_microbatches + 1
if self.num_microbatches == state.total_num_microbatches:
self.num_microbatches = 0
self._extract_grads(state)
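To make the averaging logic above easier to follow, here is a minimal standalone sketch of the same running-average bookkeeping in plain PyTorch, outside of Composer. The toy model, number of steps, and variable names are illustrative assumptions, not part of this PR:

import torch

model = torch.nn.Linear(4, 2)
running_grads = []      # plays the role of state.grads
executed_steps = 0      # plays the role of GradMonitor.executed_steps

for _ in range(3):      # stands in for three local training steps
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    grads = [p.grad.detach().clone().cpu() for p in model.parameters() if p.grad is not None]
    if running_grads:
        # incremental mean: new_avg = (old_avg * n + new) / (n + 1)
        running_grads = [(running_grads[i] * executed_steps + grads[i]) / (executed_steps + 1)
                         for i in range(len(grads))]
    else:
        running_grads = grads
    executed_steps += 1
    model.zero_grad()

# running_grads now holds the per-parameter average gradient over the three steps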
12 changes: 12 additions & 0 deletions composer/core/state.py
@@ -497,7 +497,19 @@ def __init__(

# Is the model of the fine-tuning type
is_model_finetune: bool = False,

# Running average of model gradients over all batches in one global round
grads: Optional[list] = None,

# Total number of microbatches
total_num_microbatches: int | None = None,

# Number of local steps
local_steps: int | None = 1,
):
self.grads = grads
self.total_num_microbatches = total_num_microbatches
self.local_steps = local_steps
self.rank_zero_seed = rank_zero_seed
self.model = model
self.run_name = run_name
4 changes: 4 additions & 0 deletions composer/optim/__init__.py
@@ -21,6 +21,8 @@
CosineAnnealingScheduler,
CosineAnnealingWarmRestartsScheduler,
CosineAnnealingWithWarmupScheduler,
ConstantWithLinearCooldownWithWarmupScheduler,
ConstantWithSqrtCooldownWithWarmupScheduler,
ExponentialScheduler,
LinearScheduler,
LinearWithWarmupScheduler,
@@ -41,6 +43,8 @@
'CosineAnnealingScheduler',
'CosineAnnealingWarmRestartsScheduler',
'CosineAnnealingWithWarmupScheduler',
'ConstantWithLinearCooldownWithWarmupScheduler',
'ConstantWithSqrtCooldownWithWarmupScheduler',
'ExponentialScheduler',
'LinearScheduler',
'LinearWithWarmupScheduler',
233 changes: 233 additions & 0 deletions composer/optim/scheduler.py
@@ -45,6 +45,8 @@
'ConstantWithWarmupScheduler',
'LinearWithWarmupScheduler',
'CosineAnnealingWithWarmupScheduler',
'ConstantWithLinearCooldownWithWarmupScheduler',
'ConstantWithSqrtCooldownWithWarmupScheduler',
'PolynomialWithWarmupScheduler',
]

@@ -357,6 +359,41 @@ def __call__(self, state: State, ssr: float = 1.0):
return current_factor


class SqrtScheduler(ComposerScheduler):
"""
A scheduler that decays the learning rate as ``1 - sqrt(progress)`` over the final ``t_duration`` of the schedule.
Args:
t_duration (Union[str, Time]): The duration of the decay window at the end of the schedule. Defaults to '1dur'.
t_max (Union[str, Time]): The total duration of the schedule. Defaults to '1dur'.
Attributes:
t_duration (Time): The duration of the decay window.
"""

def __init__(self, t_duration: Union[str, Time] = '1dur', t_max: Union[str, Time] = '1dur') -> None:
self.t_max = Time.from_timestring(t_max) if isinstance(t_max, str) else t_max
self.t_duration = Time.from_timestring(t_duration) if isinstance(t_duration, str) else t_duration

def __call__(self, state: State, ssr: float = 1.0) -> float:
"""
Calculate the learning rate multiplier based on the current state and schedule.
Args:
state (State): The current state of training.
ssr (float, optional): The scale schedule ratio, which rescales the schedule's time durations. Defaults to 1.0.
Returns:
float: The learning rate multiplier.
"""
t_max = _convert_time(self.t_max, state, ssr=ssr)
t_duration = _convert_time(self.t_duration, state, ssr=ssr)
current_time = state.timestamp.get(t_max.unit)
current_factor = 1 - math.sqrt((current_time - (t_max - t_duration)) / t_duration)
return current_factor
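For intuition about the shape of this cooldown, a small back-of-the-envelope sketch (plain Python, using the fraction of progress through the decay window instead of Composer Time arithmetic; the sample points are illustrative):

import math

# fraction of the way through the decay window: 0.0 at its start, 1.0 at its end
for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f'{frac:.2f} -> {1 - math.sqrt(frac):.3f}')
# prints 1.000, 0.500, 0.293, 0.134, 0.000:
# the multiplier drops steeply at the start of the cooldown and flattens toward zero at the end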


class ExponentialScheduler(ComposerScheduler):
r"""Decays the learning rate exponentially.
@@ -593,6 +630,44 @@ def _raise_if_warmup_and_max_incompatible(t_warmup: Time[int], t_max: Time[int])
)


def _raise_if_cooldown_and_max_incompatible(t_cooldown: Time[int], t_max: Time[int]):
"""Checks that t_cooldown and t_max have the same units.
_convert_time should be called on both `t_cooldown` and `t_max` before this function is called. As a result, t_cooldown and t_max will not
be TimeUnit.EPOCH.
"""
assert t_cooldown.unit != TimeUnit.EPOCH and t_max.unit != TimeUnit.EPOCH, 't_cooldown and t_max cannot be in units of EPOCH'
if isinstance(t_cooldown, str):
t_cooldown = Time.from_timestring(t_cooldown)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)
units_same = t_cooldown.unit == t_max.unit
if not units_same:
raise ValueError(
f'Cannot use cooldown scheduler with t_max {t_max} with units {t_max.unit} and t_cooldown {t_cooldown} with '
f'units {t_cooldown.unit}. t_cooldown and t_max must use the same units.',
)


def _raise_if_cooldown_plus_warmup_and_max_incompatible(t_sum: Time[int], t_max: Time[int]):
"""Checks that t_sum and t_max have the same units.
_convert_time should be called on both `t_sum` and `t_max` before this function is called. As a result, t_sum and t_max will not
be TimeUnit.EPOCH.
"""
assert t_sum.unit != TimeUnit.EPOCH and t_max.unit != TimeUnit.EPOCH, 't_sum and t_max cannot be in units of EPOCH'
if isinstance(t_sum, str):
t_sum = Time.from_timestring(t_sum)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)
units_same = t_sum.unit == t_max.unit
if not units_same:
raise ValueError(
f'Cannot use warmup plus cooldown scheduler with t_max {t_max} with units {t_max.unit} and t_sum {t_sum} with '
f'units {t_sum.unit}. t_sum and t_max must use the same units.',
)


class MultiStepWithWarmupScheduler(ComposerScheduler):
r"""Decays the learning rate discretely at fixed milestones, with an initial warmup.
@@ -875,6 +950,164 @@ def __call__(self, state: State, ssr: float = 1.0):
return _cosine_anneal(x=frac_of_total, min_y=self.alpha_f)


class ConstantWithLinearCooldownWithWarmupScheduler(ComposerScheduler):
"""
A scheduler that holds the learning rate constant, with a linear warmup period and a linear cooldown period.
Inspired by https://arxiv.org/abs/2405.18392v3.
Args:
t_warmup (Union[str, Time]): The duration of the warmup period.
t_cooldown (Union[str, Time]): The duration of the cooldown period.
t_max (Union[str, Time]): The total duration of the schedule. Defaults to '1dur'.
scale_warmup (bool, optional): If True, scales the learning rate during the warmup period. Defaults to False.
scale_cooldown (bool, optional): If True, scales the learning rate during the cooldown period. Defaults to False.
Attributes:
t_warmup (Union[str, Time]): The duration of the warmup period.
t_cooldown (Union[str, Time]): The duration of the cooldown period.
t_max (Union[str, Time]): The total duration of the schedule.
scale_warmup (bool): If True, scales the learning rate during the warmup period.
scale_cooldown (bool): If True, scales the learning rate during the cooldown period.
warmup_scheduler (LinearScheduler): The scheduler used during the warmup period.
cooldown_scheduler (LinearScheduler): The scheduler used during the cooldown period.
"""

def __init__(
self,
t_warmup: Union[str, Time],
t_cooldown: Union[str, Time],
t_max: Union[str, Time] = '1dur',
scale_warmup: bool = False,
scale_cooldown: bool = False,
) -> None:
self.t_warmup = t_warmup
self.t_cooldown = t_cooldown
self.t_max = t_max
self.scale_warmup = scale_warmup
self.scale_cooldown = scale_cooldown
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
self.cooldown_scheduler = LinearScheduler(alpha_i=1.0, alpha_f=0.0, t_max=t_cooldown)

def __call__(self, state: State, ssr: float = 1.0) -> float:
"""
Calculate the learning rate multiplier based on the current state and schedule.
Args:
state (State): The current state of training.
ssr (float, optional): The scale schedule ratio, which rescales the schedule's time durations. Defaults to 1.0.
Returns:
float: The learning rate multiplier.
"""
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'

# Convert warmup, cooldown, and max durations to the appropriate time units
t_warmup = _convert_time(self.t_warmup, state)
t_cooldown = _convert_time(self.t_cooldown, state)
t_max = _convert_time(self.t_max, state, ssr=ssr)

# Raise errors if warmup, cooldown, or max durations are incompatible
_raise_if_warmup_and_max_incompatible(t_warmup, t_max)
_raise_if_cooldown_and_max_incompatible(t_cooldown, t_max)
_raise_if_cooldown_plus_warmup_and_max_incompatible(t_cooldown + t_warmup, t_max)
_raise_if_max_duration_exceeds_t_max(t_max, state)

# If within the warmup period, use the warmup scheduler
if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

# If within the cooldown period, use the cooldown scheduler
if state.timestamp >= t_max - t_cooldown:
if self.scale_cooldown:
return self.cooldown_scheduler(state, ssr)
return self.cooldown_scheduler(state)

# Otherwise, return a constant learning rate multiplier of 1.0
return 1.0
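As a usage sketch, the new scheduler could be passed to a Trainer like any other Composer scheduler. The model, dataloader, optimizer, and durations below are assumed placeholders, not values taken from this PR:

from composer import Trainer
from composer.optim import ConstantWithLinearCooldownWithWarmupScheduler

scheduler = ConstantWithLinearCooldownWithWarmupScheduler(
    t_warmup='100ba',    # linear warmup over the first 100 batches
    t_cooldown='200ba',  # linear decay to zero over the final 200 batches
)

trainer = Trainer(
    model=model,                        # assumed to be defined elsewhere
    train_dataloader=train_dataloader,  # assumed to be defined elsewhere
    optimizers=optimizer,               # assumed to be defined elsewhere
    schedulers=[scheduler],
    max_duration='1000ba',
)
trainer.fit()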


class ConstantWithSqrtCooldownWithWarmupScheduler(ComposerScheduler):
"""
A scheduler that holds the learning rate constant, with a linear warmup period and a (1 - sqrt) cooldown period.
Inspired by https://arxiv.org/abs/2405.18392v3.
Args:
t_warmup (Union[str, Time]): The duration of the warmup period.
t_cooldown (Union[str, Time]): The duration of the cooldown period.
t_max (Union[str, Time]): The total duration of the schedule. Defaults to '1dur'.
scale_warmup (bool, optional): If True, scales the learning rate during the warmup period. Defaults to False.
scale_cooldown (bool, optional): If True, scales the learning rate during the cooldown period. Defaults to False.
Attributes:
t_warmup (Union[str, Time]): The duration of the warmup period.
t_cooldown (Union[str, Time]): The duration of the cooldown period.
t_max (Union[str, Time]): The total duration of the schedule.
scale_warmup (bool): If True, scales the learning rate during the warmup period.
scale_cooldown (bool): If True, scales the learning rate during the cooldown period.
warmup_scheduler (LinearScheduler): The scheduler used during the warmup period.
cooldown_scheduler (SqrtScheduler): The scheduler used during the cooldown period.
"""

def __init__(
self,
t_warmup: Union[str, Time],
t_cooldown: Union[str, Time],
t_max: Union[str, Time] = '1dur',
scale_warmup: bool = False,
scale_cooldown: bool = False,
) -> None:
self.t_warmup = t_warmup
self.t_cooldown = t_cooldown
self.t_max = t_max
self.scale_warmup = scale_warmup
self.scale_cooldown = scale_cooldown
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
self.cooldown_scheduler = SqrtScheduler(t_max=t_max, t_duration=t_cooldown)

def __call__(self, state: State, ssr: float = 1.0) -> float:
"""
Calculate the learning rate multiplier based on the current state and schedule.
Args:
state (State): The current state of training.
ssr (float, optional): The scale schedule ratio, which rescales the schedule's time durations. Defaults to 1.0.
Returns:
float: The learning rate multiplier.
"""
assert state.max_duration is not None, 'max_duration should be set whenever schedulers are invoked'

# Convert warmup, cooldown, and max durations to the appropriate time units
t_warmup = _convert_time(self.t_warmup, state)
t_cooldown = _convert_time(self.t_cooldown, state)
t_max = _convert_time(self.t_max, state, ssr=ssr)

# Raise errors if warmup, cooldown, or max durations are incompatible
_raise_if_warmup_and_max_incompatible(t_warmup, t_max)
_raise_if_cooldown_and_max_incompatible(t_cooldown, t_max)
_raise_if_cooldown_plus_warmup_and_max_incompatible(t_cooldown + t_warmup, t_max)
_raise_if_max_duration_exceeds_t_max(t_max, state)

# If within the warmup period, use the warmup scheduler
if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

# If within the cooldown period, use the cooldown scheduler
if state.timestamp >= t_max - t_cooldown:
if self.scale_cooldown:
return self.cooldown_scheduler(state, ssr)
return self.cooldown_scheduler(state)

# Otherwise, return a constant learning rate multiplier of 1.0
return 1.0
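To contrast the two cooldown variants added in this PR, here is a sketch of the overall multiplier as a pure function of normalized training progress. The helper name, warmup/cooldown fractions, and sample points are illustrative simplifications of the Time-based logic above, not part of the PR:

import math

def lr_multiplier(frac: float, warmup: float = 0.1, cooldown: float = 0.2, sqrt_cooldown: bool = False) -> float:
    """Learning-rate multiplier at normalized training progress frac in [0, 1]."""
    if frac < warmup:                   # linear warmup phase
        return frac / warmup
    if frac >= 1.0 - cooldown:          # cooldown window at the end of training
        progress = (frac - (1.0 - cooldown)) / cooldown
        return 1 - math.sqrt(progress) if sqrt_cooldown else 1 - progress
    return 1.0                          # constant phase in between

for frac in (0.05, 0.5, 0.85, 0.9, 0.95, 1.0):
    print(frac, round(lr_multiplier(frac), 3), round(lr_multiplier(frac, sqrt_cooldown=True), 3))
# the sqrt variant decays faster early in the cooldown and converges to the same zero endpoint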


class PolynomialWithWarmupScheduler(ComposerScheduler):
r"""Decays the learning rate according to a power of the fraction of training time left, with an initial warmup.