diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py
index 8da56eb1..4f30f289 100644
--- a/torchtitan/optimizer.py
+++ b/torchtitan/optimizer.py
@@ -38,11 +38,6 @@ def __init__(
             else:
                 raise NotImplementedError(f"Optimizer {name} not added.")
             self.optimizers.append(optimizer)
-        self.plain_optim = (
-            [self.optimizers]
-            if isinstance(self.optimizers, torch.optim.Optimizer)
-            else self.optimizers
-        )

     def step(self) -> None:
         for optimizer in self.optimizers:
@@ -58,9 +53,7 @@ def state_dict(self) -> Dict[str, Any]:
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
         return {
-            k: v
-            for sd in map(func, self.model, self.plain_optim)
-            for k, v in sd.items()
+            k: v for sd in map(func, self.model, self.optimizers) for k, v in sd.items()
         }

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -69,7 +62,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             optim_state_dict=state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),
         )
-        list(map(func, self.model, self.plain_optim))
+        list(map(func, self.model, self.optimizers))


 class OptimizersInBackwardContainer(OptimizersContainer):
@@ -103,10 +96,7 @@ def optim_hook(param) -> None:
                 if param.requires_grad:
                     param.register_post_accumulate_grad_hook(optim_hook)

-            self.optimizers.append([optim_dict[param] for param in model.parameters()])
-        self.plain_optim = [
-            sub_optim for optim_group in self.optimizers for sub_optim in optim_group
-        ]
+            self.optimizers.extend([optim_dict[param] for param in model.parameters()])

     def step(self) -> None:
         pass
@@ -187,45 +177,9 @@ def get_lr_scheduler_state(self) -> Dict[str, Any]:
         return state_dict


-class SchedulersInBackwardContainer(SchedulersContainer):
-    """Util for calling step on multiple learning rate schedulers when optimizers are in backward"""
-
-    def __init__(self, optimizers, lr_lambda) -> None:
-        # all the schedulers for each optimizer group are the same, here we only store the first one
-        # to self.schedulers follow the same structure as SchedulersContainer, but store all of them
-        # to self.all_schedulers for container.step() to call
-        self.schedulers = []
-        for optim_group in optimizers:
-            scheduler_group = []
-            for sub_optim in optim_group:
-                scheduler_group.append(LambdaLR(sub_optim, lr_lambda=lr_lambda))
-            self.schedulers.append(scheduler_group)
-
-    def step(self) -> None:
-        for scheduler_group in self.schedulers:
-            for scheduler in scheduler_group:
-                scheduler.step()
-
-    def get_lr_scheduler_state(self) -> Dict[str, Any]:
-        state_dict = {}
-        if len(self.schedulers) == 1:
-            state_dict["lr_scheduler"] = self.schedulers[0][0]
-        else:
-            # For now, pipeline-parallel with looped schedules does not support resharding for lr_scheduler.
-            # It should only support saving and loading a distributed checkpoint with the same number of pp ranks
-            for idx, lr_scheduler in enumerate(self.schedulers):
-                state_dict[f"lr_scheduler_{idx}"] = lr_scheduler[0]
-        return state_dict
-
-
 def build_lr_schedulers(optimizers, job_config: JobConfig) -> SchedulersContainer:
-    optim_in_bwd = job_config.optimizer.early_step_in_backward
     warmup_steps = int(job_config.training.warmup_steps)
     decay_steps = float(max(1, job_config.training.steps - warmup_steps))
     lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
-    return (
-        SchedulersContainer(optimizers, lr_lambda)
-        if not optim_in_bwd
-        else SchedulersInBackwardContainer(optimizers, lr_lambda)
-    )
+    return SchedulersContainer(optimizers, lr_lambda)
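
For readers unfamiliar with the optimizer-in-backward pattern this refactor simplifies, below is a minimal, self-contained sketch. It is not torchtitan code: the toy torch.nn.Linear model and the local names (optim_dict, optim_hook, optimizers, schedulers) are illustrative and only mirror the identifiers in the diff. The point it shows is why flattening self.optimizers works: each parameter gets its own optimizer that steps from a post-accumulate-grad hook, so the container only needs a flat list of per-parameter optimizers, and one LambdaLR per optimizer covers both the normal and the in-backward case without a dedicated scheduler container.

# Hypothetical standalone sketch of the technique, not torchtitan's classes.
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)

# One optimizer per parameter, stepped as soon as that parameter's gradient is ready.
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def optim_hook(param) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optim_hook)

# A flat list of optimizers, the shape the refactored container now keeps in
# self.optimizers, so a single list of LambdaLR schedulers is enough.
optimizers = [optim_dict[p] for p in model.parameters()]
schedulers = [LambdaLR(opt, lr_lambda=lambda step: 1.0) for opt in optimizers]

loss = model(torch.randn(2, 8)).sum()
loss.backward()  # the hooks step and zero each per-parameter optimizer during backward
for scheduler in schedulers:
    scheduler.step()  # schedulers still step once per training iteration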