Skip to content

Commit

Permalink
Complete FSDP-ready GradMonitor
Browse files Browse the repository at this point in the history
  • Loading branch information
Iacob-Alexandru-Andrei committed Jan 15, 2025
1 parent 5af6f93 commit c45a185
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 45 deletions.
76 changes: 44 additions & 32 deletions composer/callbacks/grad_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,18 @@

"""Monitor gradients during DDP training."""

import warnings
from typing import Union
import torch

from composer.core import Callback, State, Time, TimeUnit
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['GradMonitor']

__all__ = ["GradMonitor"]

class GradMonitor(Callback):
"""Extracts gradients from ``state.model`` during training.
The extracted gradients are stored in the self.state.grads attribute, in the form of a list of tensors.
The extracted gradients are stored in the ``state.grads`` attribute, in the form of a dict mapping parameter names to tensors.
Example:
.. doctest::
Expand All @@ -38,45 +35,60 @@ class GradMonitor(Callback):

def __init__(
self,
)-> None:
self.num_microbatches = 0
) -> None:
self.executed_steps = 0
# NOTE: May want to make this configurable
self.device = torch.device("cpu")

def _extract_grads(self, state: State, device: torch.device = torch.device('cpu')) -> None:
def _extract_grads(
self, state: State
) -> None:
"""Extracts gradients of each batch from the model
A running average of the gradients is stored in the state.
Args:
state (State): The state object.
device (torch.device, optional): The device to store the gradients. Defaults to CPU.
"""
group = list(state.model.parameters())
grad_list = []
for p in group:
if p.grad is not None:
grad_list.append(p.grad.detach().clone().to(device))

group = list(state.model.named_parameters())
grad_dict: dict[str, torch.Tensor] = {}
for name, p in group:
if p.grad is not None and p.requires_grad:
grad_dict[name] = p.grad.to(self.device).detach().clone()

# average the gradients
prev_grads = state.grads
if prev_grads:
aver_grad_list = [(prev_grads[i] * self.executed_steps + grad_list[i]) / (self.executed_steps + 1) for i in range(len(prev_grads))]
else: # the first batch, no need to average
aver_grad_list = grad_list

aver_grad_dict = {
name: (
(prev_grads[name] * self.executed_steps + grad_dict[name])
/ (self.executed_steps + 1)
)
for name in prev_grads
}
else: # the first batch, no need to average
aver_grad_dict = grad_dict
self.executed_steps = self.executed_steps + 1

if self.executed_steps == state.local_steps: # averaged gradients will be sent to the cloud, so we can reset the counter
self.executed_steps = 0

state.grads = aver_grad_list

state.grads = aver_grad_dict

def after_backward(self, state: State, logger: Logger) -> None:
"""Runs on ``Event.AFTER_BACKWARD`` in the function of _train_microbatch.
"""
assert state.total_num_microbatches is not None, "The total number of microbatch must be set"
self.num_microbatches = self.num_microbatches + 1
if self.num_microbatches == state.total_num_microbatches:
self.num_microbatches = 0
self._extract_grads(state)
"""Extract gradients on event ``Event.AFTER_BACKWARD`` in the function of _train_microbatch."""
# NOTE: Inefficient, if we know when the last microbatch is called we can avoid this
# TODO: Lorenzo, lmk where we can set the number of microbatches
self._extract_grads(state)

def batch_end(self, state: State, logger: Logger) -> None:
"""Sync gradient store on ``Event.BATCH_END`` in the function of _train_microbatch."""
assert state.grads is not None, "state.grads should not be None if this callback is used"
if dist.is_initialized() and dist.get_world_size() > 1:
for name in state.grads.keys():
last_grad = state.grads[name]

dist.all_reduce(last_grad)
last_grad.div_(dist.get_world_size())

# Should not be necessary, but just in case
state.grads[name] = last_grad

17 changes: 4 additions & 13 deletions composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,20 +496,8 @@ def __init__(
parallelism_config: Optional[ParallelismConfig] = None,

# Is the model of the fine-tuning type
is_model_finetune: bool = False,

# average model grads of all batches in one global round
grads: Optional[list] = None,

# Microbatch numbers
total_num_microbatches: int | None = None,

# local steps
local_steps: int | None = 1,
is_model_finetune: bool = False,
):
self.grads = grads
self.total_num_microbatches = total_num_microbatches
self.local_steps = local_steps
self.rank_zero_seed = rank_zero_seed
self.model = model
self.run_name = run_name
Expand Down Expand Up @@ -603,6 +591,9 @@ def __init__(
self.total_loss_dict: dict[str, float] = {}

self.metric_outputs: dict[str, Any] = {}

# average model grads of all batches in one global round
self.grads: Optional[dict[str, torch.Tensor]] = None

def _validate_parallelism_configs(self):
# Validate TP config
Expand Down

0 comments on commit c45a185

Please sign in to comment.