really ugly implementation that doesnt explode
Jackmin801 committed Jan 13, 2025
1 parent 029673e commit 7ff29a1
Showing 3 changed files with 99 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/zeroband/config.py
@@ -6,9 +6,9 @@
 from zeroband.checkpoint import CkptConfig
 from zeroband.data import DataConfig
 from zeroband.diloco import DilocoConfig
-from zeroband.global_ddp import GlobalDDPConfig
 from zeroband.models.llama.model import AttnFnType
 from zeroband.optimizers import OptimizersConfig, AdamConfig
+from zeroband.dpu import ACCOConfig
 
 
 class OptimConfig(BaseConfig):
@@ -69,7 +69,7 @@ class Config(BaseConfig):

     # sub config
     diloco: DilocoConfig | None = None
-    global_ddp: GlobalDDPConfig | None = None
+    acco: ACCOConfig | None = None
     data: DataConfig = DataConfig()
     optim: OptimConfig = OptimConfig()
     train: TrainConfig
5 changes: 5 additions & 0 deletions src/zeroband/dpu.py
@@ -0,0 +1,5 @@
+from typing import Optional
+from pydantic_config import BaseConfig
+
+class ACCOConfig(BaseConfig):
+    theta_t_device: Optional[str] = None
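
For context, the only knob on the new config is where the theta_t parameter snapshot lives. A minimal usage sketch (the "cpu" value here is illustrative, not something this commit sets anywhere):

from zeroband.dpu import ACCOConfig

# Keep the theta_t snapshot on CPU; the default (None) leaves it on the
# parameters' own device.
acco_config = ACCOConfig(theta_t_device="cpu")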
97 changes: 92 additions & 5 deletions src/zeroband/train.py
@@ -83,6 +83,9 @@ def train(config: Config):

     assert batch_size % config.train.micro_bs == 0
     gradient_accumulation_steps = batch_size // config.train.micro_bs
+    if config.acco:
+        assert gradient_accumulation_steps % 2 == 0, "ACCO requires gradient accumulation steps to be even"
+        gradient_accumulation_steps //= 2
 
     if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None:
         assert (
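
The halving works because ACCO splits every optimizer step into two stages, each consuming half of the micro-batches (one stage produces g_tilde, the other g_t). With hypothetical numbers:

batch_size, micro_bs = 512, 32
gradient_accumulation_steps = batch_size // micro_bs  # 16 micro-batches per full batch
assert gradient_accumulation_steps % 2 == 0
gradient_accumulation_steps //= 2                     # 8 micro-batches per ACCO stage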
@@ -137,7 +140,7 @@ def train(config: Config):
         apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh(
-        enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+        enable=config.diloco is not None or config.acco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
     )
 
     mp_policy = MixedPrecisionPolicy(
@@ -165,6 +168,12 @@ def train(config: Config):

     # Setup optimizers
     inner_optimizer = get_optimizer(model.parameters(), config.optim.optim)
+    if config.acco is not None:
+        first_step = True
+        reduce_work = []
+        theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad]
+        if config.acco.theta_t_device is not None:
+            theta_t = [p.to(config.acco.theta_t_device) for p in theta_t]
 
     diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None
@@ -331,10 +340,88 @@ def train(config: Config):
             if config.optim.z_loss:
                 dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg)
 
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-            inner_optimizer.step()
-            scheduler.step()
-            inner_optimizer.zero_grad()
+            if config.acco is not None:
+                # TODO: This is wrong, we overwrite g_tilde before we use it in the update
+                g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad]
+                for work in reduce_work:
+                    work.wait()
+                #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_tilde], op=dist.ReduceOp.SUM) for _g_tilde in g_tilde]
+                #reduce_work = [dist.all_reduce(_g_tilde, dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_tilde in g_tilde]
+                #a = torch.randn(10, device="cpu")
+                #work = dist.all_reduce(a, op=dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True)
+                #work.wait()
+
+                if not first_step:
+                    # Copy in theta_t and consume g_t
+                    for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde):
+                        opt_param.data.copy_(cpu_param.data, non_blocking=True)
+                        opt_param.grad.copy_(_g_t + _g_tilde, non_blocking=True)
+                        opt_param.grad /= batch_size * elastic_device_mesh.global_pg.size()
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+                    inner_optimizer.step()
+                    scheduler.step()
+                    inner_optimizer.zero_grad()
+                    # Update theta_t
+                    for param, cpu_param in zip(model.parameters(), theta_t):
+                        cpu_param.data.copy_(param.data, non_blocking=True)
+                first_step = False
+
+                for _g_tilde in g_tilde:
+                    work = dist.all_reduce(_g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True)
+                    work.wait()
+
+                # Stage 2: Compute g_t and theta_tilde
+                for grad_acc_step in range(gradient_accumulation_steps):
+                    is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
+                    # no sync if we are accumulating gradients
+                    model.set_requires_gradient_sync(not is_accumulating)
+
+                    batch = next(train_dataloader_iterator)
+                    input_ids = batch["input_ids"].to("cuda")
+                    labels = batch["labels"].to("cuda")
+                    if config.train.sequence_packing:
+                        seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]]
+                        block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None
+                    else:
+                        block_mask = None
+
+                    logits = model(tokens=input_ids, block_mask=block_mask).contiguous()
+                    flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab")
+                    flatten_labels = rearrange(labels, "b seq -> (b seq)")
+
+                    if config.optim.z_loss:
+                        ce_loss, z_loss = cross_entropy_max_z_loss(
+                            flatten_logits, flatten_labels, config.optim.z_loss_weight
+                        )
+                        ce_loss /= gradient_accumulation_steps
+                        z_loss /= gradient_accumulation_steps
+
+                        del logits
+                        loss = ce_loss + z_loss
+                        loss.backward()
+
+                    else:
+                        loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps
+                        del logits
+                        loss.backward()
+
+                    if config.optim.z_loss:
+                        loss_batch += ce_loss.clone().detach()
+                        z_loss_batch += z_loss.clone().detach()
+                    else:
+                        loss_batch += loss.clone().detach()
+
+                g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad]
+                for work in reduce_work:
+                    work.wait()
+                #reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t]
+                reduce_work = [dist.all_reduce(_g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True) for _g_t in g_t]
+
+                for opt_param, _g_tilde in zip(model.parameters(), g_tilde):
+                    opt_param.grad.copy_(_g_tilde, non_blocking=True)
+                    opt_param.grad /= (batch_size // 2 * elastic_device_mesh.global_pg.size())
+                inner_optimizer.step()
+                inner_optimizer.zero_grad()
 
             # logging
             training_progress.step += 1
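For orientation, the hunk above replaces the plain clip/step/zero_grad sequence with a two-stage schedule: gradients from the first half-batch (g_tilde) start reducing across workers while the second half-batch's backward pass produces g_t, and the "real" update is applied one step late at the synchronized weights theta_t. Below is a heavily simplified single-process sketch of that schedule; the helper names are hypothetical, communication is stubbed out, and clipping, the scheduler, and gradient normalization are omitted. Per the TODO in the commit, the committed version also still overwrites g_tilde before the delayed update consumes it.

def all_reduce_async(tensors):
    # Stand-in for dist.all_reduce(..., async_op=True) on the global
    # process group; communication is a no-op in this sketch.
    class _Done:
        def wait(self):
            pass
    return [_Done() for _ in tensors]

def acco_step(model, opt, theta_t, state, run_half_batch_backward):
    # Expects state = {"first_step": True, "reduce_work": []} before the
    # first call; run_half_batch_backward() accumulates gradients for
    # half of the micro-batches.

    # Stage 1: gradients of the first half-batch become g_tilde.
    run_half_batch_backward()
    g_tilde = [p.grad.detach().clone() for p in model.parameters()]
    for work in state["reduce_work"]:
        work.wait()  # finish last step's g_t all-reduce

    if not state["first_step"]:
        # Delayed "real" update, applied at the synchronized weights theta_t.
        for p, th, gt, gl in zip(model.parameters(), theta_t, state["g_t"], g_tilde):
            p.data.copy_(th)
            p.grad.copy_(gt + gl)  # combine both half-batch gradients
        opt.step()
        opt.zero_grad()
        for p, th in zip(model.parameters(), theta_t):
            th.copy_(p.data)  # snapshot the new synchronized weights
    state["first_step"] = False

    for work in all_reduce_async(g_tilde):
        work.wait()  # the commit reduces g_tilde synchronously here

    # Stage 2: gradients of the second half-batch become g_t; their
    # all-reduce is left in flight to overlap the next step's stage 1.
    run_half_batch_backward()
    state["g_t"] = [p.grad.detach().clone() for p in model.parameters()]
    state["reduce_work"] = all_reduce_async(state["g_t"])

    # Optimistic step using g_tilde alone, yielding the theta_tilde at
    # which the next stage-1 gradients are evaluated.
    for p, gl in zip(model.parameters(), g_tilde):
        p.grad.copy_(gl)
    opt.step()
    opt.zero_grad()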
