Commit bcb144c

restructure OptimizersInBackward class, combine self.optimizers and self.plain_optim

mori360 committed Dec 17, 2024
1 parent fa4eef9 commit bcb144c
Showing 1 changed file with 4 additions and 50 deletions.
torchtitan/optimizer.py
@@ -38,11 +38,6 @@ def __init__(
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
             self.optimizers.append(optimizer)
-        self.plain_optim = (
-            [self.optimizers]
-            if isinstance(self.optimizers, torch.optim.Optimizer)
-            else self.optimizers
-        )
 
     def step(self) -> None:
         for optimizer in self.optimizers:
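
Note: with self.plain_optim gone, the base container keeps exactly one flat list of torch.optim.Optimizer objects in self.optimizers, aligned one-to-one with the model parts in self.model, and every method iterates that single list. A minimal sketch of the resulting shape (Linear and SGD are stand-ins for torchtitan's real model parts and optimizers):

    import torch

    class Container:
        def __init__(self, model_parts) -> None:
            self.model = model_parts
            # one optimizer per model part: a flat list, no separate alias
            self.optimizers = [
                torch.optim.SGD(m.parameters(), lr=0.1) for m in model_parts
            ]

        def step(self) -> None:
            for optimizer in self.optimizers:
                optimizer.step()

    container = Container([torch.nn.Linear(2, 2), torch.nn.Linear(2, 2)])
    assert len(container.optimizers) == len(container.model)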
@@ -58,9 +53,7 @@ def state_dict(self) -> Dict[str, Any]:
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
         return {
-            k: v
-            for sd in map(func, self.model, self.plain_optim)
-            for k, v in sd.items()
+            k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
         }
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
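
Note: the state_dict change relies on that alignment: map() pairs each model part with its optimizer, and the dict comprehension merges the per-part results into one mapping. A minimal sketch of the flattened state-dict call, assuming a single toy model part (the helpers come from torch.distributed.checkpoint.state_dict, as used in the diff above):

    import functools

    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
    )

    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    func = functools.partial(
        get_optimizer_state_dict,
        options=StateDictOptions(flatten_optimizer_state_dict=True),
    )
    # pair model parts with optimizers one-to-one, then merge the per-part dicts
    merged = {k: v for sd in map(func, [model], [optim]) for k, v in sd.items()}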
@@ -69,7 +62,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             optim_state_dict=state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
-        list(map(func, self.model, self.plain_optim))
+        list(map(func, self.model, self.optimizers))
 
 
 class OptimizersInBackwardContainer(OptimizersContainer):
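
Note: the load path is symmetric. A sketch of a save/restore round trip with the same flatten option (toy single model part again; in torchtitan this runs as part of a distributed checkpoint load):

    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_optimizer_state_dict,
        set_optimizer_state_dict,
    )

    model = torch.nn.Linear(4, 4)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    opts = StateDictOptions(flatten_optimizer_state_dict=True)

    saved = get_optimizer_state_dict(model, optim, options=opts)
    # restore: the flattened keys are matched back onto the optimizer state
    set_optimizer_state_dict(model, optim, optim_state_dict=saved, options=opts)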
@@ -103,10 +96,7 @@ def optim_hook(param) -> None:
                 if param.requires_grad:
                     param.register_post_accumulate_grad_hook(optim_hook)
 
-            self.optimizers.append([optim_dict[param] for param in model.parameters()])
-        self.plain_optim = [
-            sub_optim for optim_group in self.optimizers for sub_optim in optim_group
-        ]
+            self.optimizers.extend([optim_dict[param] for param in model.parameters()])
 
     def step(self) -> None:
         pass
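
Note: switching the append() of a nested list to extend() is what keeps self.optimizers flat for the in-backward container as well. The surrounding pattern is optimizer-in-backward: one tiny optimizer per parameter, stepped from a post-accumulate-grad hook so the update happens during backward(). A self-contained sketch (requires PyTorch >= 2.1 for register_post_accumulate_grad_hook):

    import torch

    model = torch.nn.Linear(8, 8)
    # one optimizer per parameter, as optim_dict is built above
    optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

    def optim_hook(param) -> None:
        # runs as soon as this parameter's gradient is fully accumulated
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for p in model.parameters():
        if p.requires_grad:
            p.register_post_accumulate_grad_hook(optim_hook)

    loss = model(torch.randn(2, 8)).sum()
    loss.backward()  # parameters update inside backward; no separate step() call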
@@ -187,45 +177,9 @@ def get_lr_scheduler_state(self) -> Dict[str, Any]:
         return state_dict
 
 
-class SchedulersInBackwardContainer(SchedulersContainer):
-    """Util for calling step on multiple learning rate schedulers when optimizers are in backward"""
-
-    def __init__(self, optimizers, lr_lambda) -> None:
-        # every scheduler within an optimizer group is identical, so checkpointing
-        # saves only the first one; self.schedulers keeps the full nested structure
-        # so that container.step() can step every scheduler
-        self.schedulers = []
-        for optim_group in optimizers:
-            scheduler_group = []
-            for sub_optim in optim_group:
-                scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda))
-            self.schedulers.append(scheduler_group)
-
-    def step(self) -> None:
-        for scheduler_group in self.schedulers:
-            for scheduler in scheduler_group:
-                scheduler.step()
-
-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0][0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks.
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler[0]
-        return state_dict
-
-
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
-    optim_in_bwd = job_config.optimizer.early_step_in_backward
     warmup_steps = int(job_config.training.warmup_steps)
     decay_steps = float(max(1, job_config.training.steps - warmup_steps))
     lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
 
-    return (
-        SchedulersContainer(optimizers, lr_lambda)
-        if not optim_in_bwd
-        else SchedulersInBackwardContainer(optimizers, lr_lambda)
-    )
+    return SchedulersContainer(optimizers, lr_lambda)
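
Note: with the in-backward optimizers stored flat, the plain SchedulersContainer can wrap them directly, which is why the special-cased container above could be deleted. A sketch of the LambdaLR wiring; the body of linear_warmup_linear_decay here is a hypothetical stand-in for torchtitan's real factor function:

    import functools

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    # hypothetical: linearly ramp the LR factor 0 -> 1 over warmup_steps,
    # then decay 1 -> 0 over decay_steps
    def linear_warmup_linear_decay(
        warmup_steps: int, decay_steps: float, current_step: int
    ) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / (warmup_steps + 1)
        return max(0.0, 1.0 - (current_step - warmup_steps) / decay_steps)

    lr_lambda = functools.partial(linear_warmup_linear_decay, 200, 800.0)
    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)  # as the container does per optimizer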
