diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 310e7160..00000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "third_party/gloo"] - path = third_party/gloo - url = https://github.com/facebookincubator/gloo.git diff --git a/configs/debug/diloco.toml b/configs/debug/diloco.toml index c98e4603..9f36f14a 100644 --- a/configs/debug/diloco.toml +++ b/configs/debug/diloco.toml @@ -8,8 +8,8 @@ micro_bs = 8 [optim] batch_size = 16 -warmup_steps = 10 -total_steps = 4 +warmup_steps = 100 +total_steps = 1500 [data] fake = true diff --git a/pyproject.toml b/pyproject.toml index 1715a410..c7e20f04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "pyarrow", "toposolve", "psutil", + "pccl @ git+ssh://git@github.com/PrimeIntellect-ai/pccl.git@16110e15#egg=pccl&subdirectory=bindings/python", #todo move to https once open source ] [project.optional-dependencies] diff --git a/scripts/simulate_multi_node_diloco.sh b/scripts/simulate_multi_node_diloco.sh index db2e2a17..5eb4c24c 100755 --- a/scripts/simulate_multi_node_diloco.sh +++ b/scripts/simulate_multi_node_diloco.sh @@ -1,23 +1,41 @@ #!/bin/bash # -# simulate multi nodes on one gpu. start N torchrun on X gpu locally. -# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml +# Simulate multi-node on a single GPU or multiple GPUs. +# Start N torchrun instances on X GPUs locally. +# Example usage: +# ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml + +# Function to get the total number of available GPUs +get_total_gpus() { + nvidia-smi --query-gpu=name --format=csv,noheader | wc -l +} # Function to get CUDA devices based on the number of GPUs and index -function get_cuda_devices() { +get_cuda_devices() { local num_gpu=$1 local index=$2 local start_gpu=$((num_gpu * index)) local end_gpu=$((start_gpu + num_gpu - 1)) - if [ "$num_gpu" -eq 1 ]; then - echo $start_gpu + if [ "$TOTAL_GPU" -eq 1 ]; then + echo "0" + elif [ "$num_gpu" -eq 1 ]; then + echo "$start_gpu" else - echo $(seq -s ',' $start_gpu $end_gpu) + echo "$(seq -s ',' $start_gpu $end_gpu)" fi } +# Function to find an available port +find_available_port() { + local port=$1 + while ss -tuln | grep -q ":$port "; do + port=$((port + 1)) + done + echo $port +} + # Array to store PIDs of child processes child_pids=() @@ -35,37 +53,67 @@ cleanup() { exit } -# Check if at least three arguments were passed +# Register the cleanup function to be called on SIGINT (Ctrl+C) and SIGTERM +trap cleanup SIGINT SIGTERM + if [ "$#" -lt 3 ]; then - echo "Usage: $0 [additional_python_args]" + echo "Usage: $0 [additional_python_args...]" + echo "Example: $0 2 1 src/zeroband/train.py @configs/debug/normal.toml" exit 1 fi +N=$1 # Number of ranks/nodes +NUM_GPU=$2 # Number of GPUs per node +shift 2 # Shift the first two arguments so that $@ contains only additional Python arguments -N=$1 # Set N from the first argument -NUM_GPU=$2 -shift 2 # Remove the first three arguments so $@ contains only additional Python arguments - -# Register the cleanup function to be called on SIGINT (Ctrl+C) -trap cleanup SIGINT +TOTAL_GPU=$(get_total_gpus) +if [ "$NUM_GPU" -gt "$TOTAL_GPU" ]; then + echo "Requested NUM_GPU ($NUM_GPU) exceeds the total available GPUs ($TOTAL_GPU)." + echo "Setting NUM_GPU to $TOTAL_GPU." 
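+    # Clamp to the number of GPUs actually available rather than aborting.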
+ NUM_GPU=$TOTAL_GPU +fi mkdir -p logs export GLOBAL_ADDR=localhost export GLOBAL_PORT=${GLOBAL_PORT:-5565} export GLOBAL_WORLD_SIZE=$N -export BASE_PORT=${BASE_PORT:-10001} -export GLOO_SOCKET_IFNAME=lo -for i in $(seq 0 $(($N - 1 ))) -do - > logs/log$i.log - WANDB_MODE=$([ $i -eq 0 ] && echo "online" || echo "online") GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((BASE_PORT + $i)) --nnodes=1 $@ --data.data_rank $i --data.data_world_size $N > logs/log$i.log 2>&1 & +BASE_PORT=${BASE_PORT:-10001} + +for i in $(seq 0 $((N - 1))); do + LOG_FILE="logs/log$i.log" + > "$LOG_FILE" + + CUDA_DEVICES=$(get_cuda_devices "$NUM_GPU" "$i") + + # Find an available port + PORT=$(find_available_port $((BASE_PORT + i))) + + echo "Starting rank $i with CUDA_VISIBLE_DEVICES=$CUDA_DEVICES on port $PORT" + + WANDB_MODE=$([ "$i" -eq 0 ] && echo "online" || echo "online") \ + GLOBAL_UNIQUE_ID=$i \ + GLOBAL_RANK=$i \ + CUDA_VISIBLE_DEVICES="$CUDA_DEVICES" \ + torchrun --nproc_per_node="$NUM_GPU" \ + --node_rank=0 \ + --rdzv_endpoint=localhost:$PORT \ + --rdzv_id=simulate_multi_node \ + --rdzv_backend=c10d \ + --nnodes=1 \ + "$@" \ + --data.data_rank "$i" \ + --data.data_world_size "$N" \ + > "$LOG_FILE" 2>&1 & + child_pids+=($!) done -tail -f logs/log0.log & -child_pids+=($!) +if [ "$TOTAL_GPU" -ge 1 ]; then + tail -f "logs/log0.log" & + child_pids+=($!) +fi -wait +wait \ No newline at end of file diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index b3de9233..5a242c8b 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -1,609 +1,65 @@ -import sys +import ipaddress import os -import time -import subprocess -from torch.distributed.device_mesh import init_device_mesh -from zeroband.utils.world_info import get_world_info -from zeroband.utils.logging import get_logger -import torch.distributed as dist -from datetime import timedelta -from typing import List, Tuple, Optional -from torch.testing._internal.distributed.fake_pg import FakeProcessGroup -import multiprocessing as mp -from uuid import uuid4 -import toposolve -from zeroband.utils.ip import parse_iperf_output +import threading -TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300"))) -TCPSTORE_POLLING_INTERVAL = float(os.getenv("ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS", "0.1")) -GLOBAL_PG_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_PG_TIMEOUT_SECONDS", "600"))) -MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit -HEARTBEAT_INTERVAL = int( - os.getenv("ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS", "2") -) # Interval in seconds between heartbeats -HEARTBEAT_TIMEOUT = int( - os.getenv("ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS", "10") -) # Time in seconds after which a node is considered dead if no heartbeat is received -IPERF_PORT = int(os.getenv("ZERO_BAND_IPERF_PORT", "10101")) -IPERF_IFNAME = os.getenv("GLOO_SOCKET_IFNAME", "eth0") -BENCH_TENSOR_SIZE = 1_000_000 +import pccl +PCCL_INITIALIZED = False -class ElasticDeviceMesh: - """A class to manage the process groups for elastic training without restarts. - The way it works is rank 0 coordinates the joining and leaving of nodes. - Rank 0 manages the status to coordinate the creation and recreation of the process groups. - When a node wants to join, rank 0 will setup the store so that all nodes know the new world size and their respective ranks. 
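+# PCCL-backed replacement for the ElasticDeviceMesh that this diff removes.
+# Rough usage, mirroring how train.py wires it up elsewhere in this change:
+#   comm = PcclCommunicator(config.ccoip_master_ip, config.ccoip_master_port)
+#   comm.all_reduce(tensor.data_ptr(), tensor.numel())  # in-place average over peers (float32)
+#   comm.destroy()
+# Unless NO_IMPLICIT_MASTER is set in the environment, the constructor also starts
+# a CCoIP master in a background thread on this node.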
+class PcclCommunicator: - Store keys used: - - status: "init", "running", "reinit" - - world_size: The current world size - - mesh_count: The version of the mesh - - rank_{uuid}: The rank of the node with the given uuid - - joiner_{i}: The uuid of the ith joiner. Its a KV implmentation of a queue. - """ + def __init__(self, master_ip: str, master_port: int): + global PCCL_INITIALIZED + if not PCCL_INITIALIZED: + pccl.pccl_init() + PCCL_INITIALIZED = True - local_pg: dist.ProcessGroup - global_pg: dist.ProcessGroup - - def __init__( - self, backend: str = "cpu:gloo,cuda:nccl", enable: bool = True, live_recovery_rank_src: int | None = None - ): - self._logger = get_logger() - self.world_info = get_world_info() - self.live_recovery_rank_src = live_recovery_rank_src - - # Initialize global process group - self.global_pg = FakeProcessGroup(self.world_info.rank, 1) - - self.enable = enable - if enable: - self._init_global_pg() - - # Initialize local process group - dist.init_process_group(backend=backend) - self.mesh = init_device_mesh( - "cuda", - (self.world_info.nnodes, self.world_info.local_world_size), - mesh_dim_names=("internode", "intranode"), - ) - self.local_pg = self.mesh.get_group("intranode") - - # Start heartbeat - - self.cuda_local_mesh = init_device_mesh("cuda", mesh_shape=(self.local_pg.size(),)) - self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.local_pg.size(),)) - - # Logging - if self.enable: - self._optimize_ring_ranks() - if self.live_recovery_rank_src is not None: - self.live_recovery.ask_for_live_ckpt(self.live_recovery_rank_src) - self.global_pg.barrier().wait() - - self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}") - - def __del__(self): - self._stop_heartbeat() - dist.destroy_process_group() - - def _init_global_store(self): - self._logger.info( - f"[{self.world_info.global_unique_id}](Leader: {self._global_leader}) TCPStore init: Connecting via {self.world_info.global_addr}:{self.world_info.global_port + self.world_info.rank}" - ) - self.global_store = dist.TCPStore( - host_name=self.world_info.global_addr, - port=self.world_info.global_port + self.world_info.rank, - timeout=TCPSTORE_TIMEOUT, - is_master=self._global_leader, - ) - self.god_store = dist.TCPStore( - host_name=self.world_info.global_addr, - port=self.world_info.global_port, - timeout=TCPSTORE_TIMEOUT, - is_master=False, - ) - - def _init_global_store_values(self): - """Initialize the global store with mesh_count, joiner_0, and status. 
Also sets the global status.""" - self._logger.debug("Initializing global store values") - self.global_store.set(f"gid_{self.world_info.global_rank}", self.world_info.global_unique_id) - self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank)) - if self._global_leader: - self.global_store.set("mesh_count", "0") - self.global_store.set("world_size", str(self.world_info.global_world_size)) - self.global_store.set("joiner_0", "null") - for i in range(self.world_info.global_world_size): - self.global_store.set(f"barrier_{i}", "null") - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - for i in self._global_ids: - for j in self._global_ids: - self.global_store.set(f"ping_{i}_{j}", "1000_000_000") - self.global_store.set("status", "init") - self.global_status = "init" + if os.getenv("NO_IMPLICIT_MASTER") is None: + self.master_handle = pccl.pccl_create_master(use_ipv4=True) + self.master_thread = threading.Thread(target=self.run_master_blocking) + self.master_thread.start() else: - self.global_status = self._wait_for_status() - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - - def _create_global_pg(self): - # Delete the old global_pg - if hasattr(self, "global_pg"): - if sys.getrefcount(self.global_pg) > 2: - self._logger.warning( - f"Global PG refcount was {sys.getrefcount(self.global_pg)} when 2 is expected during deletion. This may cause a memory leak." - ) - del self.global_pg # TODO(jackmin): Where do we catch errors in teardown? - self._logger.info("Destroyed process group") + self.master_handle = 0 + self.master_thread = None - # Get new global rank and world size - self.world_info.global_rank = int( - self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8") - ) - self.world_info.global_world_size = int(self.global_store.get("world_size").decode("utf-8")) - self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8")) - self._logger.debug( - f"New global rank: {self.world_info.global_rank}, New global world size: {self.world_info.global_world_size} New mesh count: {self.mesh_count}" - ) + # parse ip + ip = ipaddress.ip_address(master_ip) - # Create prefix store - prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store) - self._logger.debug(f"Created prefix store with mesh_{self.mesh_count}") - - # Create process group - self._logger.debug( - f"Creating global pg with {self.world_info.global_world_size} rank {self.world_info.global_rank}" - ) - self.global_pg = dist.ProcessGroupGloo( - prefix_store, self.world_info.global_rank, self.world_info.global_world_size, GLOBAL_PG_TIMEOUT - ) - self._logger.debug("Global pg created with %d peers. 
Timeout of %s", self.global_pg.size(), GLOBAL_PG_TIMEOUT) - - def _optimize_ring_ranks(self): - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - if self.world_info.local_rank == 0: - self._logger.debug("Measuring bandwidths") - self._measure_connectivity() - self._logger.debug("Measuring bandwidths done") - - self.local_pg.barrier().wait() - self.global_pg.barrier().wait() - - if self._global_leader: - self._logger.debug("Calculating TSP") - pings = self.get_pings() - min_dist, path = toposolve.TSPSolver().solve_tsp(pings) - self._logger.debug(f"Min distance: {min_dist}") - self._logger.debug(f"Path: {path}") - new_gids = [self._global_ids[i] for i in path[:-1]] - assert set(new_gids) == set(self._global_ids) - - for i, gid in enumerate(new_gids): - self.global_store.set(f"rank_{gid}", str(i)) - self.global_store.set(f"gid_{i}", gid) - self.global_store.set("mesh_count", str(self.mesh_count + 1)) - - self.local_pg.barrier().wait() - self.global_pg.barrier().wait() - - self._global_ids = [ - self.global_store.get(f"gid_{i}").decode("utf-8") for i in range(self.world_info.global_world_size) - ] - self._create_global_pg() - - def _queue_join(self): - """Queue a node to join the mesh.""" - for i in range(MAX_JOINERS): - joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8") - if joiner_id == "null": - self.global_store.set(f"joiner_{i}", self.world_info.global_unique_id) - self.global_store.set(f"joiner_{i + 1}", "null") - break + if isinstance(ip, ipaddress.IPv4Address): + is_ipv4 = True + elif isinstance(ip, ipaddress.IPv6Address): + is_ipv4 = False else: - raise RuntimeError("Too many joiners") - - def _get_joiners(self) -> Tuple[List[str], List[str]]: - joiners = [] - for i in range(MAX_JOINERS): - joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8") - if joiner_id == "null": - break - joiners.append(joiner_id) - return joiners - - def _clear_joiners(self): - self.global_store.set("joiner_0", "null") - - def _wait_for_status(self, status: Optional[str] = None) -> str: - """Wait for status to be set in the store. - - Args: - store (dist.Store): The store to check. - status (Optional[str], optional): The status to wait for. If None, wait for any status. Defaults to None. - Returns: - status (str): The status. 
- """ - while True: - try: - ret = self.global_store.get("status").decode("utf-8") - if status is None or ret == status: - return ret - time.sleep(TCPSTORE_POLLING_INTERVAL) - except dist.DistStoreError as e: - if status is not None: - raise e - time.sleep(0.1) + raise RuntimeError(f"Invalid master ip: {master_ip}") - def _init_global_pg(self) -> None: - # Each rank gets its own global store with global rank 0 as the master - time_start = time.perf_counter() + self.communicator = pccl.pccl_create_communicator() + pccl.pccl_connect_master(self.communicator, + pccl.ccoip_inet_protocol_t.inetIPv4 if is_ipv4 else pccl.ccoip_inet_protocol_t.inetIPv6, + list(ip.packed), master_port) - self._global_leader = self.world_info.global_rank == 0 - self._init_global_store() - - # Initialize store values - self._init_global_store_values() - - self.live_recovery = LiveRecovery(store=self.global_store) - - if self.global_status == "running": # Join path - # Ask to join and then wait for the status to be "reinit" - self._logger.info("Waiting to join") - self._queue_join() - self._wait_for_status("reinit") - - # Create global process group - self._create_global_pg() - - # Update global store values - if self._global_leader: - self.global_store.set("status", "running") - self.global_store.set("resolved_time", uuid4().hex) - self.global_status = "running" - self._last_resolved_time = self.global_store.get("resolved_time").decode("utf-8") - - self._start_heartbeat() - - self._logger.info( - f"Elastic Device mesh init done with {self.global_pg.size()} peers in {time.perf_counter() - time_start} seconds" - ) - - if self.world_info.local_rank == 0: - self._start_iperf_server() - self._evicted_nodes = [] - - def _start_heartbeat(self): - """Start sending heartbeats to the global store in a separate process.""" - self._heartbeat_stop_event = mp.Event() - self._heartbeat_process = mp.Process(target=self._heartbeat_loop, args=(self._heartbeat_stop_event,)) - self._heartbeat_process.start() - - def _stop_heartbeat(self): - """Stop the heartbeat process.""" - self._send_deathrattle() - if hasattr(self, "_heartbeat_stop_event"): - self._heartbeat_stop_event.set() - self._heartbeat_process.join() - - def _heartbeat_loop(self, stop_event): - """Continuously send heartbeats until stopped.""" + def run_master_blocking(self): try: - while not stop_event.is_set(): - self._send_heartbeat() - time.sleep(HEARTBEAT_INTERVAL) - finally: - self._send_deathrattle() - - def _send_heartbeat(self): - """Send a heartbeat to the global store.""" - current_time = time.time() + pccl.pccl_run_master(self.master_handle) + except RuntimeError: + pccl.pccl_destroy_master(self.master_handle) + self.master_handle = 0 + + def interrupt_master(self): + if self.master_handle != 0: + pccl.pccl_interrupt_master(self.master_handle) + self.master_thread.join() + pccl.pccl_destroy_master(self.master_handle) + + def all_reduce(self, data_ptr, n_elements: int): + reduce_info = pccl.pcclReduceInfo_t() try: - self.global_store.set(f"heartbeat_{self.world_info.global_unique_id}", str(current_time)) - except Exception: - self._logger.error("Error sending heartbeat", exc_info=True) - pass - - def _send_deathrattle(self): - """Send a deathrattle to the global store.""" - if hasattr(self, "global_store"): - self.global_store.set(f"heartbeat_{self.world_info.global_unique_id}", "-100") - else: - import warnings - - warnings.warn("global_store garbage collected. 
Skipping deathrattle.") - - def _check_heartbeats(self) -> List[str]: - """Check heartbeats and return a list of nodes that have missed their heartbeats.""" - dead_nodes = [] - current_time = time.time() - for gid in self._global_ids: - try: - last_heartbeat = float(self.global_store.get(f"heartbeat_{gid}").decode("utf-8")) - self._logger.debug(f"Node {gid} last heartbeat: {last_heartbeat}") - if current_time - last_heartbeat > HEARTBEAT_TIMEOUT: - dead_nodes.append(gid) - self.global_store.delete_key(f"heartbeat_{gid}") - except dist.DistStoreError: - self._logger.warning(f"Node {gid} has no heartbeat") - return dead_nodes - - def _resolve_world(self, admit_joiners: bool = False) -> bool: - """Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing. - - Args: - admit_joiners (bool, optional): Whether to admit joiners. Defaults to False. - Returns: - bool: True if the world was changed, False otherwise. - """ - # Find joiners - if admit_joiners: - joiners = self._get_joiners() - else: - joiners = [] - - # Check for dead nodes - dead_nodes = self._check_heartbeats() - self._logger.debug( - "Joiners (%sadmitting): %s, Dead nodes: %s, Evicting nodes: %s", - "" if admit_joiners else "not ", - joiners, - dead_nodes, - self._evicted_nodes, - ) - dead_nodes.extend(self._evicted_nodes) - - # If no joiners or dead nodes, no resolution needed - if len(joiners) == 0 and len(dead_nodes) == 0: - return False - - # Remap live ranks to smaller world_size caused by dead nodes - leaving_nodes = set(dead_nodes) - live_ranks = [i for i in self._global_ids if i not in leaving_nodes] - for i, rank in enumerate(live_ranks): - self.global_store.set(f"rank_{rank}", str(i)) - self.global_store.set(f"gid_{i}", rank) - new_world_size = len(live_ranks) - - # Give joiners new ranks - for joiner_id in joiners: - self.global_store.set(f"rank_{joiner_id}", str(new_world_size)) - self.global_store.set(f"gid_{new_world_size}", joiner_id) - live_ranks.append(joiner_id) - new_world_size += 1 - - self._global_ids = live_ranks - for i in self._global_ids: - for j in self._global_ids: - self.global_store.set(f"ping_{i}_{j}", "1000_000_000") - for i in range(1, new_world_size): - self.global_store.set(f"barrier_{i}", "null") - # Update world_size - self.global_store.set("world_size", str(new_world_size)) - self.global_store.set("mesh_count", str(self.mesh_count + 1)) - # Set status to "reinit" - self.global_store.set("status", "reinit") - return True - - def maybe_reinit_global_pg(self, admit_joiners: bool = False) -> bool: - """Reinitialize the global_pg if there are is a state change. - - Args: - admit_joiners (bool, optional): Whether to admit joiners. Defaults to False. - Returns: - bool: True if the global_pg was reinitialized, False otherwise. 
- """ - if not self.enable: - # no op if disabled - return - - time_start = time.perf_counter() - self._logger.debug("[%s] Resolving world", self.world_info.global_unique_id) - if self._global_leader: - self._resolve_world(admit_joiners=admit_joiners) - self.global_store.set("resolved_time", uuid4().hex) - else: - while (ans := self.global_store.get("resolved_time").decode("utf-8")) == self._last_resolved_time: - # TODO: Have a timeout here in case the leader is dead - time.sleep(TCPSTORE_POLLING_INTERVAL) - self._last_resolved_time = ans - - self._logger.debug("World resolved in %s seconds", time.perf_counter() - time_start) - - status = self.global_store.get("status").decode("utf-8") - if status == "running": # No joiners or dead nodes - return False - - # Reinit Path - try: - self._create_global_pg() - self._optimize_ring_ranks() - self.global_pg.barrier().wait() - except Exception as e: - self._logger.error(f"Error recreating process group: {e}. Retrying...") - return self.maybe_reinit_global_pg(admit_joiners=admit_joiners) - - if self._global_leader: - self._clear_joiners() - self.global_store.set("status", "running") - - self._logger.debug("Reinitialized global_pg done in %s seconds", time.perf_counter() - time_start) - - # TODO: We need to reset the self.world_info.global_rank reference - # Somehow the reference becomes stale and the heartbeats become wrong - # This will be fixed when heartbeats become unique id dependent which never changes - self._logger.debug("Reset Heartbet") - self._stop_heartbeat() - self._start_heartbeat() - self._logger.debug("Reset Heartbeat done") - return True - - def get_global_pg(self, maybe_reinit: bool = False) -> dist.ProcessGroup: - """Get the global process group. If maybe_reinit is True, reinitialize the global process group if needed.""" - if maybe_reinit: - self.maybe_reinit_global_pg() - return self.global_pg - - def monitored_barrier(self, flag: str): - flag = str(flag) - time_start = time.perf_counter() - self._logger.debug("[%s] Monitored Barrier %s", self.world_info.global_unique_id, flag) - if self._global_leader: - self._logger.debug("Others have %d seconds to resolve", GLOBAL_PG_TIMEOUT.total_seconds()) - while not all( - self.global_store.get(f"barrier_{i}").decode("utf-8") == flag - for i in range(1, self.world_info.global_world_size) - ): - if time.perf_counter() - time_start > GLOBAL_PG_TIMEOUT.total_seconds(): - self._logger.error("Monitored barrier failed due to timeout") - self._evicted_nodes = [ - i - for i in range(1, self.world_info.global_world_size) - if self.global_store.get(f"barrier_{i}").decode("utf-8") != flag - ] - self._logger.info("Evicting nodes: %s", self._evicted_nodes) - self.global_store.set(f"barrier_{self.world_info.global_rank}", "error") - # We neeed to evict the dead node - raise RuntimeError("Monitored barrier failed due to timeout") - time.sleep(TCPSTORE_POLLING_INTERVAL) - self.global_store.set(f"barrier_{self.world_info.global_rank}", flag) - else: - self.global_store.set(f"barrier_{self.world_info.global_rank}", flag) - while (ans := self.global_store.get("barrier_0").decode("utf-8")) != flag: - if ans == "error": - raise RuntimeError("Monitored barrier failed due to error") - # TODO: Have a timeout here in case the leader is dead - time.sleep(TCPSTORE_POLLING_INTERVAL) - - self._logger.debug("Monitored barrier resolved in %s seconds", time.perf_counter() - time_start) - - def get_pings(self) -> List[List[int]]: - pings = [[1000_000_000] * self.world_info.global_world_size for _ in 
range(self.world_info.global_world_size)] - for i, e1 in enumerate(self._global_ids): - for j, e2 in enumerate(self._global_ids): - if i == j: - continue - pings[i][j] = int(self.god_store.get(f"ping_{e1}_{e2}")) - - self._logger.debug("\n %s", format_grid(pings)) - return pings - - def _start_iperf_server(self) -> None: - """Start the iperf server process.""" - try: - from zeroband.utils.ip import get_ip_address - - iperf_addr = get_ip_address(IPERF_IFNAME) - iperf_port = IPERF_PORT + self.world_info.global_rank - cmd: List[str] = ["iperf", "-s", "-p", str(iperf_port)] - self.server_process = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - self.god_store.set(f"iperf_{self.world_info.global_unique_id}", f"{iperf_addr}:{iperf_port}") - self._logger.info(f"Started iperf server on {iperf_addr} with port {iperf_port}") - except Exception as e: - self._logger.error(f"Failed to start iperf server: {str(e)}") - raise - - def _measure_connectivity(self): - for i in self._global_ids: - if i == self.world_info.global_unique_id: - continue - target_host, target_port = self.god_store.get(f"iperf_{i}").decode("utf-8").split(":") - target_port = int(target_port) - time_taken = self.measure_bandwidth(target_host, target_port) - self.god_store.set(f"ping_{self.world_info.global_unique_id}_{i}", str(time_taken)) - - def measure_bandwidth(self, target_host: str, target_port: int) -> int: - """ - Measure bandwidth to a specific target. - - Args: - target_host: The host to measure bandwidth to - target_port: The port to measure bandwidth to - - Returns: - int: The time taken to transfer 10Tb of data in seconds - """ - try: - cmd: List[str] = [ - "iperf", - "-c", - target_host, - "-p", - str(target_port), - "-t", - "1", # 1 second test - ] - result: subprocess.CompletedProcess = subprocess.run(cmd, capture_output=True, text=True, timeout=5) - - if result.returncode != 0: - raise Exception(f"iperf error: {result.stderr}") - - time_taken: int = int(1e13 / parse_iperf_output(result.stdout)) - time_taken = min(time_taken, 1_000_000_000) - - return time_taken + pccl.pccl_allreduce(data_ptr, data_ptr, n_elements, pccl.pcclDataType_t.pcclFloat, pccl.pcclRedOp_t.pcclAvg, + 1, + self.communicator, reduce_info) except Exception as e: - self._logger.error(f"Error measuring bandwidth to {target_host}:{target_port} {str(e)}") - return int(1e9) - - -def format_grid(grid): - N = len(grid) - - # Set the main diagonal elements to 0 - for i in range(N): - grid[i][i] = 0 - - # Determine the width needed for formatting based on max possible value (99.99) and indices - cell_width = 6 - - # Create header row with column indices - header_row = " " + " | ".join(f"{j:>{cell_width-1}}" for j in range(N)) - - # Start building the formatted grid string - formatted_grid = header_row + "\n" - - for i, row in enumerate(grid): - # Format each element in the row - formatted_row = [f"{i:>2}"] # Add row index at the beginning of the row - for value in row: - # Divide by 1000 and format to 2 decimal places - formatted_value = f"{value / 1000:.2f}" - formatted_row.append(formatted_value) - - # Join the elements of the row with '|' and add it to the grid string - formatted_grid += " | ".join(formatted_row).center(cell_width * (N + 1)) + "\n" - - return formatted_grid.strip() - - -class LiveRecovery: - def __init__(self, store: dist.Store): - self.logger = get_logger() - self.world_info = get_world_info() - - self.store = dist.PrefixStore("live_recovery", store) - self.reset() - - def reset(self): - 
self.store.set(f"rank_{self.world_info.global_rank}", "null") - - def should_send_ckpt_to(self) -> int | None: - """use this function to check if someone is awaiting for a live ckpt""" - data = self.store.get(f"rank_{self.world_info.global_rank}").decode("utf-8") - if data == "null": - return None - try: - return int(data) - except ValueError as e: - self.logger.error(f"Error parsing live recovery data: {e}") - return None + print("all_reduce failed", e) - def ask_for_live_ckpt(self, rank: int) -> int | None: - """use this function to send a signal to a node to ask for a live ckpt""" - self.store.set(f"rank_{rank}", str(self.world_info.global_rank)) + def destroy(self): + pccl.pccl_destroy_communicator(self.communicator) + self.interrupt_master() diff --git a/src/zeroband/diloco.py b/src/zeroband/diloco.py index 2e387055..e0a027a0 100644 --- a/src/zeroband/diloco.py +++ b/src/zeroband/diloco.py @@ -3,13 +3,13 @@ from pydantic_config import BaseConfig import torch from torch import nn -from zeroband.collectives import Compression, all_reduce -from zeroband.comms import ElasticDeviceMesh +from zeroband.collectives import Compression +from zeroband.comms import PcclCommunicator from zeroband.utils.world_info import get_world_info from zeroband.utils.logging import get_logger -import torch.distributed as dist from torch.distributed._tensor.api import DTensor from functools import lru_cache +from torch.distributed.device_mesh import init_device_mesh class DilocoConfig(BaseConfig): @@ -56,22 +56,17 @@ class Diloco: """ def __init__( - self, - config: DilocoConfig, - model: nn.Module, - elastic_device_mesh: ElasticDeviceMesh, + self, + config: DilocoConfig, + model: nn.Module, + pccl_communicator: PcclCommunicator, ): self.config = config - - if config.compression == Compression.UINT8: - from zeroband.C.collectives import ring_allreduce as _ # noqa: F401 - # just force compilation - - self.elastic_device_mesh = elastic_device_mesh + self.pccl_communicator = pccl_communicator self._logger = get_logger() self.world_info = get_world_info() - + self.cpu_local_mesh = init_device_mesh("cpu", mesh_shape=(self.world_info.local_world_size,)) self._init_offloaded_optimizer(model=model) @torch.no_grad() @@ -89,14 +84,10 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = """ _start_time = time.perf_counter() - self.elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) - world_size_post_init = self.elastic_device_mesh.global_pg.size() - - world_size = world_size_post_init + world_size = 1 # todo - self._logger.debug("sync pseudo gradient %s with world size %d", " fake" if fake else "", world_size) + self._logger.debug("sync pseudo gradient %s", " fake" if fake else "") - global_pg = self.elastic_device_mesh.global_pg for i in range(self.config.retry_all_reduce): for param_offloaded, param in zip(self.param_list_cpu, model.parameters()): if fake: @@ -104,19 +95,19 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = else: param_offloaded.grad.to_local().copy_(param_offloaded.data.to_local()) param_offloaded.grad.to_local().sub_(param.data.to_local().to(param_offloaded.data.device)) + try: self.offloaded_grad_flat_tensor.div_(world_size) _collective_start_time = time.perf_counter() - self._logger.debug("Waiting on barrier") - self.elastic_device_mesh.monitored_barrier(flag) - self._logger.debug("Beginning all reduce") - # all_reduce(self.config.compression, self.offloaded_grad_flat_tensor, dist.ReduceOp.SUM, global_pg) + 
self._logger.debug(f"Beginning all reduce attempt {i + 1}/{self.config.retry_all_reduce}") for j, tensor_group in enumerate(self._offloaded_grad_grouped_tensor): t0 = time.perf_counter() - all_reduce(self.config.compression, tensor_group, dist.ReduceOp.SUM, global_pg) + + self.pccl_communicator.all_reduce(tensor_group.data_ptr(), tensor_group.numel()) + self._logger.debug( - f"{j}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}" + f"{j + 1}/{len(self._offloaded_grad_grouped_tensor)} all reduce bucket done in {time.perf_counter() - t0:.6f} seconds, numel: {tensor_group.numel()}" ) self._logger.debug( @@ -124,8 +115,7 @@ def sync_pseudo_gradient(self, model: nn.Module, fake: bool = False, flag: str = ) break except Exception as e: - self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i+1}/{self.config.retry_all_reduce}") - global_pg = self.elastic_device_mesh.get_global_pg(maybe_reinit=True) + self._logger.error(f"Error syncing pseudo gradient: {e}, retry {i + 1}/{self.config.retry_all_reduce}") else: self._logger.error( "Failed to sync pseudo gradient after %d retries. Resorting to calculating pseudo-gradient without reduce", @@ -181,14 +171,14 @@ def get_offloaded_param(self, model: nn.Module) -> list[nn.Parameter]: offloaded_param = nn.Parameter( DTensor.from_local( data_tensor, - device_mesh=self.elastic_device_mesh.cpu_local_mesh, + device_mesh=self.cpu_local_mesh, placements=param.data.placements, ) ) offloaded_param.grad = DTensor.from_local( grad_tensor, - device_mesh=self.elastic_device_mesh.cpu_local_mesh, + device_mesh=self.cpu_local_mesh, placements=param.data.placements, ) # here we pre-allocate the grad DTensor on cpu. diff --git a/src/zeroband/train.py b/src/zeroband/train.py index fc47503b..0b0680f1 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -17,7 +17,7 @@ import torch.distributed as dist from zeroband import utils from zeroband.diloco import Diloco, DilocoConfig -from zeroband.comms import ElasticDeviceMesh +from zeroband.comms import PcclCommunicator from zeroband.loss import cross_entropy_max_z_loss from zeroband.utils import ( @@ -38,6 +38,8 @@ from zeroband.checkpoint import CkptConfig, CkptManager, TrainingProgress from zeroband.lr_scheduler import get_scheduler +from pccl import PROTOCOL_PORT_MASTER + class OptimConfig(BaseConfig): lr: float = 4e-4 @@ -108,12 +110,14 @@ class Config(BaseConfig): monitor: MonitorConfig | None = None ckpt: CkptConfig = CkptConfig() + ccoip_master_ip: str = "127.0.0.1" + ccoip_master_port: int = PROTOCOL_PORT_MASTER @model_validator(mode="after") def ckpt_diloco_step(self): if self.ckpt is not None and self.ckpt.interval is not None and self.diloco is not None: assert ( - self.ckpt.interval % self.diloco.inner_steps == 0 + self.ckpt.interval % self.diloco.inner_steps == 0 ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step" return self @@ -134,7 +138,7 @@ def train(config: Config): if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: assert ( - config.ckpt.interval % config.diloco.inner_steps == 0 + config.ckpt.interval % config.diloco.inner_steps == 0 ), "ckpt interval must be a multiple of diloco inner steps as we only save at the end of an outer step" if config.type_model == "llama2": @@ -180,9 +184,9 @@ def train(config: Config): num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt 
apply_ac_ckpt(model, num) - elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src - ) + pccl_communicator = PcclCommunicator(config.ccoip_master_ip, config.ccoip_master_port) + + dist.init_process_group() mp_policy = MixedPrecisionPolicy( param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None @@ -196,13 +200,11 @@ def train(config: Config): fully_shard( transformer_block, mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, reshard_after_forward=reshard_after_forward, ) fully_shard( model, mp_policy=mp_policy, - mesh=elastic_device_mesh.cuda_local_mesh, reshard_after_forward=config.train.reshard_after_forward, ) logger.debug("model fsdped") @@ -216,7 +218,7 @@ def train(config: Config): ) if config.diloco is not None: - diloco = Diloco(config.diloco, model, elastic_device_mesh) + diloco = Diloco(config.diloco, model, pccl_communicator) scheduler = get_scheduler( sched_type=config.optim.sched_type, @@ -284,7 +286,7 @@ def train(config: Config): logger.info("starting training") - need_live_recovery = config.ckpt.live_recovery_rank_src is not None + # need_live_recovery = config.ckpt.live_recovery_rank_src is not None while True: if num_inner_steps > 1: # if we don't use diloco we don't print the outer step logs @@ -292,51 +294,51 @@ def train(config: Config): time_start_outer = time.perf_counter() - if config.diloco is not None: - # this is a patch for now to allow live recovery worker to not affect the all reduce at all - - if not need_live_recovery: - elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True) - - maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() - if maybe_dest_rank is not None: - logger.info(f"Start live recovery to rank {maybe_dest_rank}") - if config.train.log_model_hash: - logger.info( - f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" - ) - logger.info( - f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}" - ) - logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") - - ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) - - elastic_device_mesh.live_recovery.reset() - else: - ## receiving - time_start_live_recovery = time.perf_counter() - logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}") + # if config.diloco is not None: + # # this is a patch for now to allow live recovery worker to not affect the all reduce at all - ## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it + # if not need_live_recovery: + # elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=True) - diloco.outer_optimizer.step() # need to step to init the DTensor stats + # maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to() + # if maybe_dest_rank is not None: + # logger.info(f"Start live recovery to rank {maybe_dest_rank}") + # if config.train.log_model_hash: + # logger.info( + # f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" + # ) + # logger.info( + # f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}" + # ) + # logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") - ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) + # 
ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True) - if config.train.log_model_hash: - logger.info( - f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" - ) - logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}") - logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") + # elastic_device_mesh.live_recovery.reset() + # else: + # ## receiving + # time_start_live_recovery = time.perf_counter() + # logger.info(f"Start live recovery from rank {config.ckpt.live_recovery_rank_src}") + + # ## we create grad buffer and opts stats mamnually, the value will be overwritten by the ckpt but we need the DTensor to be correctly init before loading it - need_live_recovery = False + # diloco.outer_optimizer.step() # need to step to init the DTensor stats - if config.ckpt.remote_data_load: - ckpt_manager.remote_data_load() + # ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg) - logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery) + # if config.train.log_model_hash: + # logger.info( + # f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}" + # ) + # logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}") + # logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}") + + # need_live_recovery = False + + # if config.ckpt.remote_data_load: + # ckpt_manager.remote_data_load() + + # logger.info("live recovery done in %f", time.perf_counter() - time_start_live_recovery) # at the beginning of the inner steps we allow joiner to arrive. # We maybe reinit before the all reduce but only to allow leaving, not to join anymore @@ -390,9 +392,9 @@ def train(config: Config): else: loss_batch += loss.clone().detach() - dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) + dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG) if config.optim.z_loss: - dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) + dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) inner_optimizer.step() @@ -413,7 +415,10 @@ def train(config: Config): else: # we count the total tokens with respect to all diloco workers # might need to tweak this as some worker might fail to join the all reduce later - training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() + + # training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size() # todo need to know the size + training_progress.total_tokens += new_tokens + remaining_cpu_ram = psutil.virtual_memory().available / (1024 * 1024 * 1024) metrics = { @@ -440,12 +445,12 @@ def train(config: Config): if tokens_per_second is not None: metrics["tokens_per_second"] = tokens_per_second metrics["mfu"] = ( - 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size + 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size ) log += f", tokens_per_second: {tokens_per_second:.2f}, mfu: {metrics['mfu']:.2f}" if config.diloco is not None: - metrics["num_peers"] = elastic_device_mesh.global_pg.size() + metrics["num_peers"] = 1 # elastic_device_mesh.global_pg.size() log += f", diloco_peers: {metrics['num_peers']}" if world_info.rank == 0: @@ -478,9 +483,9 @@ def 
train(config: Config): training_progress.outer_step += 1 if ( - config.ckpt.interval is not None - and training_progress.step > 0 - and training_progress.step % config.ckpt.interval == 0 + config.ckpt.interval is not None + and training_progress.step > 0 + and training_progress.step % config.ckpt.interval == 0 ): # we only allow to checkpoint after a outer step. For non diloco training outer step = 1 anyway @@ -496,10 +501,10 @@ def train(config: Config): if config.diloco: tokens_per_second = ( - config.optim.batch_size - * config.diloco.inner_steps - * config.data.seq_length - / (time.perf_counter() - time_start_outer) + config.optim.batch_size + * config.diloco.inner_steps + * config.data.seq_length + / (time.perf_counter() - time_start_outer) ) mfu = 100 * num_flop_per_token * tokens_per_second / gpu_peak_flops / world_info.local_world_size logger.info(f"effective mfu: {mfu}") @@ -531,7 +536,8 @@ def train(config: Config): ckpt_manager.wait_for_blocking_job() - del elastic_device_mesh # allow to clean up for smoother tests transition + pccl_communicator.destroy() + del pccl_communicator # allow to clean up for smoother tests transition logger.info("Training finished, exiting ...") diff --git a/third_party/gloo b/third_party/gloo deleted file mode 160000 index 5354032e..00000000 --- a/third_party/gloo +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5354032ea08eadd7fc4456477f7f7c6308818509 diff --git a/uv.lock b/uv.lock index e127728c..77ac3534 100644 --- a/uv.lock +++ b/uv.lock @@ -1125,6 +1125,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/22/a5/a0b255295406ed54269814bc93723cfd1a0da63fb9aaf99e1364f07923e5/pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23", size = 11498828 }, ] +[[package]] +name = "pccl" +version = "0.1.0" +source = { git = "ssh://git@github.com/PrimeIntellect-ai/pccl.git?subdirectory=bindings%2Fpython&rev=16110e15#16110e15afe132be0803a9f651734c871e6b074a" } + [[package]] name = "platformdirs" version = "4.3.2" @@ -2166,6 +2171,7 @@ dependencies = [ { name = "fsspec", extra = ["gcs"] }, { name = "ninja" }, { name = "numpy" }, + { name = "pccl" }, { name = "psutil" }, { name = "pyarrow" }, { name = "pydantic-config" }, @@ -2202,6 +2208,7 @@ requires-dist = [ { name = "fsspec", extras = ["gcs"], specifier = ">=2024.3.1" }, { name = "ninja" }, { name = "numpy" }, + { name = "pccl", git = "ssh://git@github.com/PrimeIntellect-ai/pccl.git?subdirectory=bindings%2Fpython&rev=16110e15" }, { name = "psutil" }, { name = "pyarrow" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" },