diff --git a/src/zeroband/config.py b/src/zeroband/config.py
index 61ac6e56..a82a043a 100644
--- a/src/zeroband/config.py
+++ b/src/zeroband/config.py
@@ -6,9 +6,9 @@
 from zeroband.checkpoint import CkptConfig
 from zeroband.data import DataConfig
 from zeroband.diloco import DilocoConfig
-from zeroband.global_ddp import GlobalDDPConfig
 from zeroband.models.llama.model import AttnFnType
 from zeroband.optimizers import OptimizersConfig, AdamConfig
+from zeroband.dpu import ACCOConfig
 
 
 class OptimConfig(BaseConfig):
@@ -69,7 +69,7 @@ class Config(BaseConfig):
 
     # sub config
     diloco: DilocoConfig | None = None
-    global_ddp: GlobalDDPConfig | None = None
+    acco: ACCOConfig | None = None
     data: DataConfig = DataConfig()
     optim: OptimConfig = OptimConfig()
     train: TrainConfig
diff --git a/src/zeroband/dpu.py b/src/zeroband/dpu.py
new file mode 100644
index 00000000..d6c5161a
--- /dev/null
+++ b/src/zeroband/dpu.py
@@ -0,0 +1,5 @@
+from typing import Optional
+from pydantic_config import BaseConfig
+
+class ACCOConfig(BaseConfig):
+    theta_t_device: Optional[str] = None
diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 87f4635d..2e535e61 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -83,6 +83,9 @@ def train(config: Config):
 
     assert batch_size % config.train.micro_bs == 0
     gradient_accumulation_steps = batch_size // config.train.micro_bs
+    if config.acco:
+        assert gradient_accumulation_steps % 2 == 0, "ACCO requires gradient accumulation steps to be even"
+        gradient_accumulation_steps //= 2
 
     if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None:
         assert (
@@ -137,7 +140,7 @@ def train(config: Config):
         apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh(
-        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+        enable=config.diloco is not None or config.acco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
     mp_policy = MixedPrecisionPolicy(
@@ -165,6 +168,12 @@ def train(config: Config):
 
     # Setup optimizers
     inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)
+    if config.acco is not None:
+        first_step = True
+        reduce_work = []
+        theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad]
+        if config.acco.theta_t_device is not None:
+            theta_t = [p.to(config.acco.theta_t_device) for p in theta_t]
 
     diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None
 
@@ -331,10 +340,88 @@ def train(config: Config):
             if config.optim.z_loss:
                 dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
 
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            inner_optimizer.step()
-            scheduler.step()
-            inner_optimizer.zero_grad()
+            if config.acco is not None:
+                # TODO: This is wrong, we overwrite g_tilde before we use it in the update
+                g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad]
+                for work in reduce_work:
+                    work.wait()
+                #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde]
+                #reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde]
+                #a = torch.randn(10, device="cpu")
+                #work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True)
+                #work.wait()
+
+                if not first_step:
+                    # Copy in theta_t and consume g_t
+                    for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde):
+                        opt_param.data.copy_(cpu_param.data, non_blocking=True)
+                        opt_param.grad.copy_(_g_t + _g_tilde, non_blocking=True)
+                        opt_param.grad /= batch_size * elastic_device_mesh.global_pg.size()
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                    inner_optimizer.step()
+                    scheduler.step()
+                    inner_optimizer.zero_grad()
+                    # Update theta_t
+                    for param, cpu_param in zip(model.parameters(), theta_t):
+                        cpu_param.data.copy_(param.data, non_blocking=True)
+                first_step = False
+
+                for _g_tilde in g_tilde:
+                    work = dist.all_reduce(_g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True)
+                    work.wait()
+
+                # Stage 2: Compute g_t and theta_tilde
+                for grad_acc_step in range(gradient_accumulation_steps):
+                    is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
+                    # no sync if we are accumulating gradients
+                    model.set_requires_gradient_sync(not is_accumulating)
+
+                    batch = next(train_dataloader_iterator)
+                    input_ids = batch["input_ids"].to("cuda")
+                    labels = batch["labels"].to("cuda")
+                    if config.train.sequence_packing:
+                        seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]
+                        block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
+                    else:
+                        block_mask = None
+
+                    logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                    flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
+                    flatten_labels = rearrange(labels, "b seq -> (b seq)")
+
+                    if config.optim.z_loss:
+                        ce_loss, z_loss = cross_entropy_max_z_loss(
+                            flatten_logits, flatten_labels, config.optim.z_loss_weight
+                        )
+                        ce_loss /= gradient_accumulation_steps
+                        z_loss /= gradient_accumulation_steps
+
+                        del logits
+                        loss = ce_loss + z_loss
+                        loss.backward()
+
+                    else:
+                        loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps
+                        del logits
+                        loss.backward()
+
+                    if config.optim.z_loss:
+                        loss_batch += ce_loss.clone().detach()
+                        z_loss_batch += z_loss.clone().detach()
+                    else:
+                        loss_batch += loss.clone().detach()
+
+                g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad]
+                for work in reduce_work:
+                    work.wait()
+                #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t]
+                reduce_work = [dist.all_reduce(_g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_t in g_t]
+
+                for opt_param, _g_tilde in zip(model.parameters(), g_tilde):
+                    opt_param.grad.copy_(_g_tilde, non_blocking=True)
+                    opt_param.grad /= (batch_size // 2 * elastic_device_mesh.global_pg.size())
+                inner_optimizer.step()
+                inner_optimizer.zero_grad()
 
             # logging
             training_progress.step += 1
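
Note on the added train.py branch: the ACCO path halves gradient_accumulation_steps and runs two accumulation passes per step, so that the all-reduce of the first half's gradients (g_tilde) can overlap with the backward passes that produce the second half's gradients (g_t). The sketch below is a minimal single-process illustration of that overlap pattern only; it is not the zeroband code, it stands in a throwaway linear model and a single-rank gloo group for elastic_device_mesh.global_pg, and it omits the delayed theta_t bookkeeping (which the TODO above still marks as incorrect).

    # Minimal sketch of overlapping a gradient all-reduce with computation.
    # Assumptions: single process, gloo backend, toy model; illustrative only.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F


    def main():
        # Single-rank group so the script runs standalone; in zeroband this role
        # is played by elastic_device_mesh.global_pg.
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
        model = torch.nn.Linear(16, 4)
        opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
        half_a, half_b = torch.randn(8, 16), torch.randn(8, 16)
        target_a, target_b = torch.randint(0, 4, (8,)), torch.randint(0, 4, (8,))

        # Stage 1: gradients on the first half-batch (g_tilde), then launch their
        # all-reduce asynchronously.
        F.cross_entropy(model(half_a), target_a).backward()
        g_tilde = [p.grad.detach().clone() for p in model.parameters()]
        works = [dist.all_reduce(g, op=dist.ReduceOp.SUM, async_op=True) for g in g_tilde]
        opt.zero_grad()

        # Stage 2: gradients on the second half-batch (g_t) are computed while the
        # reduction of g_tilde is in flight, then reduced as well.
        F.cross_entropy(model(half_b), target_b).backward()
        for p in model.parameters():
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
        for work in works:
            work.wait()

        # Combine the reduced g_t (now in p.grad) with the reduced g_tilde, average
        # over the two halves and the world size, then take a normal optimizer step.
        world_size = dist.get_world_size()
        for p, g in zip(model.parameters(), g_tilde):
            p.grad.add_(g).div_(2 * world_size)
        opt.step()
        opt.zero_grad()
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

In the PR itself, the first half is the pre-existing gradient-accumulation loop earlier in train() and the second half is the newly added "Stage 2" loop, each covering gradient_accumulation_steps // 2 micro-batches of the original budget.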