Add profiling and chunked cross-entropy #192

Merged · 12 commits · Jan 14, 2025 · Changes from all commits

3 changes: 3 additions & 0 deletions .gitignore
@@ -165,3 +165,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Aider
.aider*
12 changes: 6 additions & 6 deletions README.md
@@ -98,10 +98,10 @@ GLOO_SOCKET_IFNAME=lo GLOBAL_ADDR=localhost GLOBAL_RANK=0 GLOBAL_UNIQUE_ID=0 GLO
To test DiLoCo locally you can use the helper script `scripts/simulate_multi_node_diloco.sh`

```bash
-# Using 2 GPUs with 2 simulated DiLoCo workers
-ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 1 2 src/zeroband/train.py @configs/debug/diloco.toml
+# Using 4 GPUs (2 diloco workers, each across 2 GPUs)
+ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml

-# Using 2 GPUs with 1 simulated DiLoCo worker
+# Using 2 GPUs (2 diloco workers, each on a single GPU)
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
```

@@ -140,14 +140,14 @@ uv run accelerate launch -m lm_eval --model hf --model_args pretrained=CONVERTED
### Elastic Device Mesh Configuration
| Environment Variable | Description | Default Value |
|-----------------------|--------------------------------------------------|---------------|
-| `ZERO_BAND_LOG_LEVEL` | Enable debug mode for loge | `False` |
+| `ZERO_BAND_LOG_LEVEL` | Enable debug log lines | `False` |
| `ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS` | Number of seconds before the global store operations timeout | `300` |
| `ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS` | Number of seconds before the global process group operations timeout | `600` |
| `ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS` | Number of seconds between polls to the store when waiting for values | `0.1` |
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
| `ZERO_BAND_LIVE_RECO_PORT` | Port number for the live recovery server | random |
| `ZERO_BAND_LIVE_RECO_ADDR` | IP Address for the live recovery server | `localhost` |
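As an aside (not part of the diff), a minimal Python sketch of how settings like these are typically consumed, using the defaults from the table above; the `_env_float` helper and the choice of variables are illustrative assumptions, not zeroband's actual parsing code:

```python
import os

# Hypothetical helper mirroring the defaults listed in the table above;
# the real zeroband code may parse these variables differently.
def _env_float(name: str, default: float) -> float:
    raw = os.getenv(name)
    return float(raw) if raw is not None else default

global_store_timeout = _env_float("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", 300)
heartbeat_interval = _env_float("ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS", 2)
heartbeat_timeout = _env_float("ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS", 10)
live_reco_addr = os.getenv("ZERO_BAND_LIVE_RECO_ADDR", "localhost")
```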

## Troubleshooting

7 changes: 7 additions & 0 deletions configs/10B/H100_devel.toml
@@ -1,16 +1,23 @@
name_model = "10B"
project = "debug_I2_zero_band"
metric_logger_type = "dummy"

[train]
micro_bs = 1
ac_ckpt = true
torch_profiler = false

[train.memory_profiler]
freq = 1
snapshot_dir = "logs/"

[optim]
sched_type = "wsd-sqrt"
batch_size = 128 #1M tokens bs
warmup_steps = 0
total_steps = 1
lr = 7.5e-5
num_chunks = 8

adam_betas1 = 0.9
adam_betas2 = 0.95
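The new `torch_profiler` flag and `[train.memory_profiler]` table drive the profiling added in this PR. The training-loop wiring itself is not part of this diff, so the following is only a hedged sketch of the standard PyTorch hooks such flags usually gate; the `train_step` stub, the call sites, and the file naming are assumptions:

```python
import os
import torch
from torch.profiler import ProfilerActivity, profile

torch_profiler = True     # mirrors `torch_profiler` in the TOML above
memory_profiler_freq = 1  # mirrors `freq`
snapshot_dir = "logs/"    # mirrors `snapshot_dir`

def train_step(step: int) -> None:
    pass  # placeholder for the real forward/backward/optimizer step

if torch_profiler:
    # Operator-level profiling via torch.profiler.
    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
    with profile(activities=activities) as prof:
        train_step(0)
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

if torch.cuda.is_available():
    # CUDA memory snapshots use the private `_record_memory_history` /
    # `_dump_snapshot` APIs available in recent PyTorch releases.
    os.makedirs(snapshot_dir, exist_ok=True)
    torch.cuda.memory._record_memory_history()
    for step in range(2):
        train_step(step)
        if step % memory_profiler_freq == 0:
            torch.cuda.memory._dump_snapshot(os.path.join(snapshot_dir, f"memory_step{step}.pickle"))
```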
67 changes: 0 additions & 67 deletions scripts/simulate_multi_node.sh

This file was deleted.

15 changes: 8 additions & 7 deletions scripts/simulate_multi_node_diloco.sh
@@ -25,20 +25,20 @@ child_pids=()
cleanup() {
echo "Cleaning up child processes..."
local killed=0

# First kill the main processes
for pid in "${child_pids[@]}"; do
if kill -TERM "$pid" 2>/dev/null; then
((killed++))
fi
done

# Kill the tail process if it exists
if [ -n "$tail_pid" ]; then
kill -TERM "$tail_pid" 2>/dev/null
((killed++))
fi

wait
echo "All child processes terminated. Killed $killed processes."
exit
@@ -51,9 +51,10 @@ if [ "$#" -lt 3 ]; then
fi


-N=$1 # Set N from the first argument
-NUM_GPU=$2
-shift 2 # Remove the first three arguments so $@ contains only additional Python arguments
+N=$1 # The number of processes
+NUM_GPU=$2 # The number of GPUs used by each process
+# Remove the first three arguments so $@ contains only additional Python arguments
+shift 2

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT
@@ -86,4 +86,4 @@ done
# Once main processes are done, kill the tail process
if [ -n "$tail_pid" ]; then
kill -TERM "$tail_pid"
fi
2 changes: 1 addition & 1 deletion src/zeroband/checkpoint.py
@@ -224,7 +224,7 @@ def __init__(
scheduler: LambdaLR,
dataloader: StatefulDataLoader,
training_progress: TrainingProgress,
-data_rank: int,
+data_rank: int | None,
diloco_offloaded_param_list: list[nn.Parameter] | None,
diloco_offloaded_optimizer: Optimizer | None,
):
3 changes: 3 additions & 0 deletions src/zeroband/config.py
@@ -26,6 +26,7 @@ class OptimConfig(BaseConfig):

z_loss: bool = False
z_loss_weight: float = 2e-4
num_chunks: int | None = None


class MemoryProfilerConfig(BaseConfig):
@@ -45,6 +46,8 @@ class TrainConfig(BaseConfig):

memory_profiler: MemoryProfilerConfig | None = None

torch_profiler: bool = False

sequence_packing: bool = True

attn_fn: AttnFnType = "flex"
3 changes: 1 addition & 2 deletions src/zeroband/data.py
@@ -7,7 +7,6 @@
from zeroband.utils.logging import get_logger

import torch
-from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset, Dataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.stateful import Stateful
@@ -300,7 +299,7 @@ def get_dataloader(
rank: int,
batch_size: int,
data_config: DataConfig,
-) -> DataLoader:
+) -> StatefulDataLoader:
if data_config.fake:
train_dataset = FakeTokenizedDataset(data_config.seq_length, TEST_VOCAB_SIZE)
else:
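For context on why the return annotation matters: `StatefulDataLoader` (unlike a plain `DataLoader`) exposes `state_dict`/`load_state_dict`, which is what makes the dataloader checkpointable and resumable. A minimal, self-contained sketch; the toy dataset and values are invented for illustration:

```python
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader

class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> int:
        return idx

loader = StatefulDataLoader(ToyDataset(), batch_size=2)
it = iter(loader)
next(it)                      # consume the first batch

state = loader.state_dict()   # capture the data-loading position

resumed = StatefulDataLoader(ToyDataset(), batch_size=2)
resumed.load_state_dict(state)
print(next(iter(resumed)))    # resumes after the first batch
```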
63 changes: 47 additions & 16 deletions src/zeroband/loss.py
@@ -2,32 +2,63 @@
import torch
import torch.nn.functional as F


+def compute_cross_entropy_loss(
+    logits: Tensor,
+    labels: Tensor,
+    z_weight: float | None = None,
+    num_chunks: int | None = None,
+    ignore_index: int = -100,
+) -> tuple[Tensor, Tensor | None]:
+    """
+    Compute cross entropy loss in fp32, optionally chunked, and optionally with max z loss.
-@torch.compile
-def cross_entropy_max_z_loss(
-    logits: Tensor,
-    targets: Tensor,
-    z_loss_weight: float,
-    ignore_index: int = -100,
-) -> Tensor:
-    """MaxZLoss.
+
+    Do not torch compile this function if you set num_chunks >= 1. It will unroll the chunking loop, thus removing the benefit of chunking.
+
-    from the baichuan2 paper: https://arxiv.org/abs/2309.10305
+    Max z loss is from the baichuan2 paper: https://arxiv.org/abs/2309.10305

    .. math::
        z_{loss} = weight z^{2}

    where z is the max logit
    """
-    logits = logits.reshape(-1, logits.size(-1))
-    targets = targets.reshape(-1)
+    num_elements = (labels != ignore_index).sum().float()
+
+    if num_chunks is not None and not num_chunks <= 1:
+        l_labels: list[Tensor] = [target_chunk.reshape(-1) for target_chunk in labels.chunk(num_chunks, dim=0)]
+        l_logits: list[Tensor] = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits.reshape(-1, logits.size(-1)).chunk(num_chunks, dim=0)]
+    else:
+        l_labels: list[Tensor] = [labels.reshape(-1)]
+        l_logits: list[Tensor] = [logits.reshape(-1, logits.size(-1))]
+
+    loss = 0.0
+    ce_loss = None if z_weight is None else 0.0
+    for logits_chunk, labels_chunk in zip(l_logits, l_labels):
+        if z_weight is None:
+            loss += _upcast_cross_entropy(logits_chunk, labels_chunk, ignore_index=ignore_index)
+        else:
+            ce, z = _upcast_cross_entropy_max_z(logits_chunk, labels_chunk, z_weight, ignore_index=ignore_index)
+            loss += ce
+            ce_loss += z
+
+    return (loss / num_elements), (None if ce_loss is None else ce_loss / num_elements)

-    loss = F.cross_entropy(logits, targets, ignore_index=ignore_index)

+# Compile the upcast into the CE calculation
+@torch.compile
+def _upcast_cross_entropy(logit_chunk, label_chunk, ignore_index) -> Tensor:
+    return F.cross_entropy(logit_chunk.float(), label_chunk, ignore_index=ignore_index, reduction="sum")
+
+
+@torch.compile
+def _upcast_cross_entropy_max_z(
+    logits: Tensor,
+    targets: Tensor,
+    z_loss_weight: float,
+    ignore_index: int = -100,
+) -> tuple[Tensor, Tensor]:
+    # max is not differentiable. But here we just pick the indices of the max value, so it's fine for backpropagation.
+    loss = F.cross_entropy(logits.float(), targets, ignore_index=ignore_index, reduction="sum")
    max_logits = logits.max(dim=-1)[0]
    max_logits = max_logits.where(targets != ignore_index, 0)
-    # max is not differentiable. But here we just pick the indices of the max
-    # value, so it's fine for backpropagation.

    z_loss = z_loss_weight * max_logits.pow(2).mean()
    return loss, z_loss
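For clarity, a hedged usage sketch of `compute_cross_entropy_loss` as defined above. The shapes, the `num_chunks` value, and the way the two terms are combined before `backward()` are illustrative assumptions, not the PR's actual call site in the training loop:

```python
import torch
from zeroband.loss import compute_cross_entropy_loss

batch, seq, vocab = 4, 16, 1024
logits = torch.randn(batch, seq, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq))
labels[0, :4] = -100  # ignored positions, e.g. padding

# Plain fp32 cross entropy, no chunking, no z loss.
ce, _ = compute_cross_entropy_loss(logits, labels)

# Chunked over the batch dimension, with the max-z auxiliary loss enabled
# (the repo config uses num_chunks = 8; 4 here only to match the toy batch).
ce, z = compute_cross_entropy_loss(logits, labels, z_weight=2e-4, num_chunks=4)
(ce + z).backward()
```

Note that the second return value is the normalized max-z term (or `None` when `z_weight` is not given), despite the internal `ce_loss` accumulator name.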
4 changes: 2 additions & 2 deletions src/zeroband/optimizers.py
@@ -1,4 +1,4 @@
-from typing import Literal, TypeAlias
+from typing import Iterable, Literal, TypeAlias
from pydantic_config import BaseConfig
import torch
from distributed_shampoo import (
@@ -29,7 +29,7 @@ class SoapConfig(BaseConfig):
OptimizersConfig: TypeAlias = AdamConfig | SoapConfig


-def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
+def get_optimizer(params: Iterable[torch.nn.Parameter], config: OptimizersConfig) -> torch.optim.Optimizer:
if isinstance(config, AdamConfig):
return torch.optim.AdamW(
params,
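A small sketch of what the widened `Iterable` annotation permits: passing `model.parameters()` (a generator) straight through without a `list(...)` copy. Whether `AdamConfig` can be constructed with defaults is an assumption not shown in this diff:

```python
import torch
from zeroband.optimizers import AdamConfig, get_optimizer

model = torch.nn.Linear(8, 8)

# model.parameters() is a generator, not a list; with the new annotation it
# type-checks directly against get_optimizer.
optimizer = get_optimizer(model.parameters(), AdamConfig())  # assumption: default-constructible
print(type(optimizer).__name__)  # AdamW
```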