From 45b0bcee9b3e98b1efae2919bbf050ff84758e56 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 21 Nov 2024 19:05:26 -0500 Subject: [PATCH 01/26] Add distributed datasets --- .../stateful_dataloader/ibm_rescalable.py | 604 ++++++++++++++++++ 1 file changed, 604 insertions(+) create mode 100644 torchdata/stateful_dataloader/ibm_rescalable.py diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py new file mode 100644 index 000000000..f457d9d0b --- /dev/null +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -0,0 +1,604 @@ +import logging +import math +import os +from copy import deepcopy +from typing import Any, Callable, List + +import torch +import torch.distributed as dist +import torch.distributed.tensor as dtensor +import torch.utils.data as data + +from .stateful_dataloader import StatefulDataLoader + +""" +The following distributed dataloaders are designed around 3 main principles: + +1. Efficient, asynchronous operation. Workers on different devices do not communicate. +2. Modularity. Data loading pipeline is composed of wrapped iterators, the base iterator + loading from disk and additional layers adding levels of post-processing (shuffling, + packing, padding, rescaling, etc.). +3. Seamless resumption from checkpoint. Each stage of the pipeline maintains an internal + state that can be written/read on disk via implemented recursive `state_dict()` and + `load_state_dict()` calls. Any values that should be saved to state can be designated + 'state_params' and will be automatically included in the state dict. States must be + valid targets of torch.tensor(). +4. Rescalability. Users can save and load checkpoints to/from different numbers of workers + without losing the global state. This is accomplished by splitting the global state over + a predefined large number of small partitions, each of which tracks its own individual + state. Rescaling is accomplished by re-distributing these shards over the physical workers. + +Our loaders obey the following type hierarchy: +torch.data.IterableDataset -> _StatefulDataset -> _WrapperDataset. +`_StatefulDataset` implements state and checkpointing logic. A `_WrapperDataset` holds a +single `_StatefulDataset` and iterates via calling its wrapped dataset any number of times, +then applying some sort of post-processing and yielding the result. Users build data processing +pipelines by wrapping a base `_StatefulDataset` in any number of `_WrapperDataset` layers, +which is then passed to the torch DataLoader. + +It is likely that this can be merged into the existing Nodes structure, but we leave this for +future work, for now. +""" + + +def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + Partition itemlist into worldsize chunks, grab chunk corresponding to rank and return. + """ + return itemlist[(rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize] + + +def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: + """ + In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, + and return the span including all owned items, fractional or otherwise. + """ + start = math.floor(len(itemlist) * rank / worldsize) + end = math.ceil(len(itemlist) * (rank + 1) / worldsize) + return itemlist[start:end] + + +class _StatefulDataset(data.IterableDataset): + """ + Stub for stateful datasets, extends data.IterableDataset with state_dict methods. + All subclasses should specify the params to be considered stateful via self.state_params. + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + ): + assert rank >= 0, f"Rank {rank} must be a positive integer" + assert worldsize > rank, f"Worldsize {worldsize} must be greater than rank {rank}" + assert datapath is None or ( + os.path.isdir(datapath) and len(os.listdir(datapath)) > 0 + ), f"Data path {datapath} must be a non-empty folder or None" + self.state_params: List[str] = [] + + # Default fields + self.datapath = datapath + self.rank = rank + self.worldsize = worldsize + self.local_worldsize = -1 + + # Setup / loading flags + self.is_setup = False + + def setup(self): + """ + This method should contain all setup depending on datapath or rank. + It is called after init, but immediately before any other operation. + Certain operations higher up in the pipeline may change rank or datapath + after init (for example, wrapping in a subdataset sampler layer, or copying + to worker processes), so all rank- and datapth- dependent ops are deferred to + this function. + Currently, this function simply adjusts rank/worldsize to account for + multiprocess dataloaders. + """ + if not self.is_setup: + self.is_setup = True + # Perform adjustment only if not already adjusted (i.e. via _WrapperDataset) + if self.local_worldsize == -1: + info = data.get_worker_info() + if info is None or info.num_workers == 1: + # No multi-worker rank adjustment needed + self.local_worldsize = 1 + else: + self.local_worldsize = info.num_workers + self.worldsize = self.worldsize * self.local_worldsize + self.rank = self.local_worldsize * self.rank + info.id + + def statename(self, x: str): + # Note that this naming convention implicitly disallows repeated layers in the dataset pipeline + return self.__class__.__name__ + "." + x + + def state_dict(self): + """ + Retrieve all state_params (each worker/process produces its own state dict shard). + On the off chance that you're saving a checkpoint with zero steps, run setup first. + """ + self.setup() + return {self.statename(flag): getattr(self, flag) for flag in self.state_params} + + def load_state_dict(self, state_dict): + """ + Run setup if needed, and apply all applicable state_params from the state_dict. + """ + self.setup() + [setattr(self, flag, state_dict[self.statename(flag)]) for flag in self.state_params] + + +class _WrapperDataset(_StatefulDataset): + """ + Stub for nested wrappers of _StatefulDatasets. Extends state fns with recursion. + Requires a single instantiated sub-dataset (which may be replicated during setup fn). + """ + + def __init__( + self, + dataset: _StatefulDataset, + ): + self.dataset = dataset + # Inherit default flags from sub-dataset + super().__init__(self.dataset.datapath, self.dataset.rank, self.dataset.worldsize) + + def setup(self): + """ + Datapath/rank/worldsize percolate upwards recursively during initialization, so + now we project any desired changes downward, also recursively. + We also project local_worldsize downward to prevent subsequent layers from + further inflating the rank/worldsize - we only need to account for multiprocessing once! + Any code overriding this function should still include this functionality. + """ + if not self.is_setup: + super().setup() + self.dataset.datapath = self.datapath + self.dataset.rank = self.rank + self.dataset.worldsize = self.worldsize + self.dataset.local_worldsize = self.local_worldsize + self.dataset.setup() + + def load_state_dict(self, state_dict): + """ + Sets all specified flags at the current level, then recurses into wrapped dataset. + """ + self.setup() + super().load_state_dict(state_dict) + self.dataset.load_state_dict(state_dict) + + def state_dict(self): + """ + Fetches state dict recursively from wrapped layers, then adds specified flags. + Overlapping flags are overwritten with a warning. + """ + self.setup() + out = self.dataset.state_dict() + state = super().state_dict() + for flag in self.state_params: + if flag in out: + logging.warning( + f"Loader {self.rank}: flag {flag} already present in state_dict with value {out[flag]}. " + + f"Overwriting with value {state[flag]}" + ) + out.update(state) + return out + + +#### ------------------------- DATASET LAYERS ------------------------- #### + + +class PreprocessDataset(_WrapperDataset): + """ + Wrapper for a _StatefulDataset that applies a specified preprocessing + or augmentation function to dataset outputs. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset + aug_fn : function (any -> any) + The augmentation function to apply to each dataset item. + """ + + def __init__( + self, + dataset: _StatefulDataset, + aug_fn: Callable, + ): + super().__init__(dataset) + self.aug_fn = aug_fn + + def __iter__(self): + dataset = iter(self.dataset) + while True: + out = next(dataset) + yield self.aug_fn(out) + + +class SamplingDataset(_WrapperDataset): + """ + A _WrapperDataset implementing percentage-based sampling: weights can be floats, and the + number of tokens seen from each subdataset will match those weights as closely as possible. + This is accomplished by maintaining a _StatefulDataset for each subdataset, and tracking + the number of tokens emitted by each. Whichever loader is furthest from its target will be + the next to pass a document. + Relies on eos token to determine document boundaries, so must sit below BufferDataset. + ... + Args + ---- + datapath : str + Absolute path to the dataset directory. Expects directory to contain subfolders, + which in turn contain shard files. + dataset : _StatefulDataset + Fully instantiated dataset. Cloned across desired subdatasets during setup. + delimiter_token : Any + Token used to indicate sequence/document breaks. Type should match data type. + datasets : list[str] | None + A list of subdatasets to draw from. If None, draws from all subfolders of datapath. + weights : list(float) | None + Weights describing what percent of emitted tokens should come from each subdataset. + Need not sum to 1. If None, tokens are drawn evenly. + verbose : bool + Track setup progress? + """ + + def __init__( + self, + datapath: str, + dataset: _StatefulDataset, + delimiter_token: Any, + datasets=None, + weights=None, + verbose=False, + ): + super().__init__(dataset) + self.datapath = datapath + self.delimiter = delimiter_token + self.verbose = verbose + self.datasets = ( + datasets + if datasets is not None + else [f for f in os.listdir(datapath) if not os.path.isfile(os.path.join(datapath, f)) and "meta" not in f] + ) + assert len(self.datasets) > 0, "You must specify at least one dataset" + + if weights is not None: + assert len(weights) == len( + self.datasets + ), f"Number of oversample weights {len(weights)} must match number of datasets {len(self.datasets)}" + for w in weights: + assert w > 0, f"Sampling rate {w} must be positive" + self.weights = [1] * len(self.datasets) if weights is None else weights + self.weights = [w / sum(self.weights) for w in self.weights] + + self.tokens_seen = [0] * len(self.datasets) + + self.current_iterator = -1 + self.state_params = ["tokens_seen", "current_iterator"] + + def setup(self): + if not self.is_setup: + _StatefulDataset.setup(self) + # Build subdataset iterators + self.data = [] + for i, d in enumerate(self.datasets): + self.data.append(deepcopy(self.dataset)) + self.data[-1].datapath = os.path.join(self.datapath, d) + self.data[-1].rank = self.rank + self.data[-1].worldsize = self.worldsize + self.data[-1].local_worldsize = self.local_worldsize + if self.verbose: + logging.info( + f"Worker {self.rank} assembled subdataset iterator for {d}, {i+1} of {len(self.datasets)}" + ) + [d.setup() for d in self.data] + + def __iter__(self): + self.setup() + # Grab one doc at a time in random order + data = [iter(d) for d in self.data] + while True: + if self.current_iterator != -1: + # Finish current document + out = next(data[self.current_iterator]) + self.tokens_seen[self.current_iterator] += len(out) + if out[-1] == self.delimiter: + self.current_iterator = -1 + yield out + else: + # Choose new subdataset to draw from + # (whichever is currently most underrepresented compared to target rate) + offset = [ + self.weights[i] - self.tokens_seen[i] / (sum(self.tokens_seen) + 1e-9) + for i in range(len(self.datasets)) + ] + offset_argmax = max((diff, i) for i, diff in enumerate(offset))[1] + self.current_iterator = offset_argmax + + def state_dict(self): + self.setup() + # Manually add state of all subloaders to self state + iterator_states = [d.state_dict() for d in self.data] + assert len(iterator_states) > 0, f"Worker {self.rank} owns no datasets" + # Flip list[dict[any]] to dict[list[any]] + prefix = self.statename("states.") + out = {prefix + k: [d[k] for d in iterator_states] for k in iterator_states[0].keys()} + out.update(_StatefulDataset.state_dict(self)) + return out + + def load_state_dict(self, state_dict): + self.setup() + # Load stats + _StatefulDataset.load_state_dict(self, state_dict) + # Load sub-iterator states + prefix = self.statename("states.") + # Flip dict[list[any]] to list[dict[any]] + iterator_states = [ + {k[k.find(prefix) + len(prefix) :]: v[i] for k, v in state_dict.items() if prefix in k} + for i in range(len(self.data)) + ] + # Load individual state sub-dicts + [self.data[i].load_state_dict(iterator_states[i]) for i in range(len(self.data))] + + +class DummyDataset(_StatefulDataset): + """ + A dummy base dataset for demo purposes. + + Normally this dataset would be responsible for using rank, datapath and worldsize arguments + to perform dataset partitioning, and implement repeating iteration over its particular data shard. + + Spits out random sequences of desired vocab size / seq length as lists. + Places delimiter token at end of each sequence (used by SamplingDataset). + """ + + def __init__( + self, + datapath: str, + rank: int, + worldsize: int, + delimiter_token: Any, + seed: int = 42, + vocab: int = 100, + seqlen: int = 64, + ): + super().__init__(datapath, rank, worldsize) + self.vocab = vocab + self.seqlen = seqlen + self.delimiter = delimiter_token + # Ensure different seeds across ranks and datasets, for demo purposes + seed = seed + self.rank + len(datapath) * 100 + self.generator = torch.Generator().manual_seed(seed) + self.g_state = None + self.state_params = ["g_state"] + + def __iter__(self): + while True: + out = torch.rand(self.seqlen, generator=self.generator) + out = out.mul(self.vocab).int().tolist() + out[-1] = self.delimiter + yield out + + def state_dict(self): + # Write generator state manually + self.g_state = self.generator.get_state().tolist() + return super().state_dict()() + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + # Manually set generator state if it exists + if self.g_state is not None: + self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) + + +class ScalableShardDataset(_WrapperDataset): + """ + A _WrapperDataset implementing rescalability: loading from checkpoint into a different + number of gpus will nonetheless keep avoiding all data previously seen in the current epoch. + This is accomplished by maintaining a large number of smaller StatefulDatasets, cloned from the + original dataset arg with adjusted ranks, which track state individually and reshard over n_gpus. + Rescaling only works when this layer wraps all other layers that contribute to state_dict. + ... + Args + ---- + dataset : _StatefulDataset + Fully instantiated dataset. Cloned into logical workers during setup fn. + n_logical_shards : int + Total number of logical shards. Must be a multiple of world size. + verbose : bool + Track setup progress? + """ + + def __init__( + self, + dataset: _StatefulDataset, + n_logical_shards: int = 2048, + verbose=False, + ): + super().__init__(dataset) + assert ( + n_logical_shards % self.worldsize == 0 + ), f"World size {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly" + assert n_logical_shards > 0, f"n_logical_shards {n_logical_shards} must be a positive integer" + + self.total_shards = n_logical_shards + self.verbose = verbose + + # Fields to be populated during setup / subdataset setup + self.data: List[_StatefulDataset] = [] + self.logicals_owned: List[int] = [] + self.n_logicals = 0 + + # Position "state", used only for maintaining order when n_workers is unchanged + # For scaling up or down, logical position is meaningless, and reset + self.current_reader = 0 + self.load_worldsize = self.worldsize + + self.state_params = ["current_reader"] # self.data states are handled manually + + def setup(self): + if not self.is_setup: + _StatefulDataset.setup(self) + n_logical_shards = self.total_shards + logicals = list(range(n_logical_shards)) + self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize) + self.n_logicals = n_logical_shards // self.worldsize + assert ( + len(self.logicals_owned) == self.n_logicals + ), "(world size * num workers) does not divide logical shards evenly" + + # Build logical shards + for i in range(self.n_logicals): + self.data.append(deepcopy(self.dataset)) + self.data[-1].worldsize = n_logical_shards + self.data[-1].rank = self.logicals_owned[i] + self.data[-1].local_worldsize = 1 + self.data[-1].datapath = self.datapath + self.data[-1].verbose = self.rank == 0 + if self.verbose: + logging.info( + f"Worker {self.rank} assembled logical shard {self.logicals_owned[i]}, {i+1} of {self.n_logicals}" + ) + [d.setup() for d in self.data] + + def __iter__(self): + self.setup() + # Grab one item at a time, iterating over owned logical shards + data = [iter(d) for d in self.data] + while True: + ind = self.current_reader + # Read doc + out = next(data[ind]) + # Update state + self.current_reader = (self.current_reader + 1) % self.n_logicals + yield out + + def state_dict(self): + self.setup() + # Recursive fetch + logical_shard_states = [d.state_dict() for d in self.data] + assert len(logical_shard_states) > 0, f"Worker {self.rank} owns no shards???" + # Flip list[dict[Any]] to dict[list[Any]] + state_dict = {k: [d[k] for d in logical_shard_states] for k in logical_shard_states[0].keys()} + state_dict.update(_StatefulDataset.state_dict(self)) + + # Convert to tensor form + out = {} + for k, v in state_dict.items(): + v = torch.tensor(v) + if len(v.shape) == 0: + k = k + ".scalar" + v = v.unsqueeze(0) + out[k] = v + + return out + + def load_state_dict(self, state_dict): + self.setup() + + # Convert back to lists and scalars + def detorchify(k, v): + v = v.tolist() + if ".scalar" in k: + k = k[:-7] + v = v[0] + return k, v + + plain_dict = {} + for k, v in state_dict.items(): + k, v = detorchify(k, v) + plain_dict[k] = v + state_dict = plain_dict + + # Assemble logical shard states + # TODO: how is this handling non-resharding state_params when resharding??? + _StatefulDataset.load_state_dict(self, state_dict) + # Remove all non-resharding state + [state_dict.pop(self.statename(n)) for n in self.state_params] + # Flip dict[list[any]] to list[dict[any]] + logical_shard_states = [{k: v[i] for k, v in state_dict.items()} for i in range(self.n_logicals)] + + # Load values + for i in range(self.n_logicals): + self.data[i].load_state_dict(logical_shard_states[i]) + + +#### ------------------------- CHECKPOINT FUNCTIONS ------------------------- #### + + +def __pop_dstate(state, device_mesh, placements): + """ + Removes worker states from the StatefulDataLoader state dict, and assembles them + into a separate list of dicts for distributed checkpointing. + """ + dstate = state["_snapshot"]["_worker_snapshots"] + dstate = [dstate[f"worker_{i}"].pop("dataset_state") for i in range(len(dstate))] + # Flip list[dict[tensor]] to dict[list[tensor]], and concat + dstate = {k: torch.cat([d[k] for d in dstate], 0) for k in dstate[0]} + # Construct dtensors from tensors + dstate = { + k: dtensor.DTensor.from_local( + v, + device_mesh, + placements, + ) + for k, v in dstate.items() + } + return dstate + + +def save_distributed_state_dict( + loader: StatefulDataLoader, + path: str, + device_mesh=None, +): + rank = loader.dataset.rank + state = deepcopy(loader.state_dict()) + dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) + # Write distributed state dict + writer = dist.checkpoint.FileSystemWriter(path) + dist.checkpoint.save( + dstate, + writer, + ) + # Write nondistributed state dict + torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) + + +def load_distributed_state_dict( + loader: StatefulDataLoader, + path: str, + device_mesh=None, +): + base = loader.state_dict() + nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] + rank = loader.dataset.rank + dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) + # Read nondistributed state dict + ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x]) + # Check that number of loaders matches + if ckp_ws == loader.dataset.worldsize: + state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth")) + # Check that number of workers matches + if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: + state = base + else: + # On mismatch, discard saved non-reshardable loader state and start fresh + state = base + # Read distributed state dict + reader = dist.checkpoint.FileSystemReader(path) + dist.checkpoint.load_state_dict( + dstate, + reader, + ) + # Get local tensors from dtensors, and slice over workers + dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()} + # Flip dict[list[tensor]] to list[dict[tensor]] + dstate = [{k: v[i] for k, v in dstate.items()} for i in range(nworkers)] + # Re-insert worker states into loader state + for i in range(nworkers): + state["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"] = dstate[i] + # Load into loader + loader.load_state_dict(state) From e486614c757aee390c57d073e1b14fc341b669f5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Thu, 21 Nov 2024 19:09:19 -0500 Subject: [PATCH 02/26] Formatting, commenting --- torchdata/stateful_dataloader/ibm_rescalable.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index f457d9d0b..99328e1d7 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -554,6 +554,13 @@ def save_distributed_state_dict( path: str, device_mesh=None, ): + """ + Retrieves dataloader state dict, and separates worker states from loader state. + Loader state is not rescalable, and is saved using normal torch.save. + It is discarded when rescaling. + Rescalable worker states are compiled into a dtensor across ranks, and saved + using pytorch distributed checkpointing. + """ rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) @@ -572,6 +579,13 @@ def load_distributed_state_dict( path: str, device_mesh=None, ): + """ + Retrieves dataloader state dict, and separates worker states from loader state. + If not rescaling, load saved dataloader state. + Rescalable worker states are retrieved using pytorch distributed checkpointing. + States are distributed over workers, and ScalableShardDataset will handle + partitioning and re-assignment of available states into logical ranks. + """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] rank = loader.dataset.rank From 10e45b9ff0e4d65d07dcb940153d708f5fa31416 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 19:58:27 -0500 Subject: [PATCH 03/26] Add demo script --- examples/ibm_rescaling/rescaling_demo.py | 146 +++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 examples/ibm_rescaling/rescaling_demo.py diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py new file mode 100644 index 000000000..8d71b35d7 --- /dev/null +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -0,0 +1,146 @@ +import argparse +import os + +import torch +from torch import distributed as dist + +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.ibm_rescalable import ( + DummyDataset, + PreprocessDataset, + SamplingDataset, + ScalableShardDataset, + load_distributed_state_dict, + save_distributed_state_dict, +) + +# This example script validates the rescaling behavior of the ibm rescalable distributed datasets. +# On first run, saves a distributed checkpoint to the desired location. +# On subsequent runs, loads the checkpoint (possibly on a different world size / num workers) +# and verifies that previous data is not revisited, while upcoming data is. + +# Example usage: +# torchrun [torchrun args] examples/ibm_rescaling/rescaling_demo.py --ckpt_path=~/ckpts/rescale_test --logical_shards=48 --num_workers=6 + + +parser = argparse.ArgumentParser(description="Script to validate rescaling of dataloader checkpoints") +parser.add_argument("--ckpt_path", type=str, default="./rescale_test") +parser.add_argument( + "--logical_shards", + type=int, + default=96, + help="Total number of data partitions. (worldsize * n_workers) must divide this evenly.", +) +parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers per device") +parser.add_argument("--b_size", type=int, default=1, help="Number of data points per step per device") +parser.add_argument("--seed", type=int, default=42) + +args = parser.parse_args() + +# Setup +rank = int(os.getenv("RANK", 0)) +world_size = int(os.getenv("WORLD_SIZE", 1)) +dist.init_process_group() +mesh = dist.device_mesh.init_device_mesh("cpu", [world_size]) +placement = [dist.tensor.placement_types.Shard(0)] + +# Build dataloader +data = DummyDataset("not_a_real_datapath", rank, world_size, delimiter_token=-1, seed=args.seed) +# Pretend that we're sampling over multiple sub-datasets +data = SamplingDataset( + "not_a_real_datapath", + data, + delimiter_token=-1, + datasets=["sub_dataset", "second_subdataset", "small_subdataset"], + weights=[12, 17, 5], +) +# Apply rescalability layer +data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +# Statelessly convert all outputs to tensors +data = PreprocessDataset(data, torch.tensor) +# Wrap in StatefulDataLoader +data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) + +# If checkpoint does not exist, create it +if not os.path.exists(args.ckpt_path) or len(os.listdir(cfg.ckpt_save_path)) == 0: + os.makedirs(args.ckpt_path, exist_ok=True) + # Iterate, assemble values to exclude + if rank == 0: + print("No existing checkpoint. Processing 100 steps.") + + avoid = [] + for i, inp in enumerate(data): + if i == 100: + if rank == 0: + print("Iteration complete!") + save_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + break + avoid.append(inp) + avoid = torch.cat(avoid) + # Get all vals onto each rank + avoid = dist.tensor.DTensor.from_local( + avoid, + mesh, + placement, + ).full_tensor() + + # Continue, assemble values to include + load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + if rank == 0: + print("DCP state loaded!") + + include = [] + for i, inp in enumerate(data): + if i == 10: + break + include.append(inp) + include = torch.cat(include) + if rank == 0: + print("Iteration round 2 complete!") + # Get all vals onto each rank + include = dist.tensor.DTensor.from_local(include, mesh, placement).full_tensor() + + if rank == 0: + torch.save(avoid, os.path.join(args.ckpt_path, "avoid.pth")) + torch.save(include, os.path.join(args.ckpt_path, "include.pth")) + print( + "Generation complete! Please rerun (with different world size / workers if desired) to complete the check." + ) + +# If checkpoint does exist, load and take 100 steps. +# Ensure avoid values are avoided, and all include values are included. +else: + if rank == 0: + print("Checkpoint detected!") + load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + + vals = [] + for i, inp in enumerate(data): + if i == 100: + break + vals.append(inp) + vals = torch.cat(vals) + # Get all vals onto each rank + vals = dist.tensor.DTensor.from_local(vals, mesh, placement).full_tensor() + + # Perform avoid/include checks on rank 0 only + if rank == 0: + avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) + include = torch.load(os.path.join(args.ckpt_path, "include.pth")) + + def _in(v, m): + # Returns whether vector v is a row of matrix m (both tensors) + return m.sub(v[None]).abs().sum(1).sign().prod().bool().logical_not().item() + + # Avoid check + for i, x in enumerate(avoid.split(1)): + assert not _in(x[0], vals), i + print("Check passed: seen data was not revisited!") + + # Include check + for i, x in enumerate(include.split(1)): + assert _in(x[0], vals), i + print("Check passed: upcoming data appears as expected!") + +dist.barrier() +dist.destroy_process_group() From 10a6f66ec857fb0a2b92d155ddfedad2da9a45d3 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:13:34 -0500 Subject: [PATCH 04/26] Datapath None --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 8d71b35d7..bd9885800 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,10 +45,10 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset("not_a_real_datapath", rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets data = SamplingDataset( - "not_a_real_datapath", + None, data, delimiter_token=-1, datasets=["sub_dataset", "second_subdataset", "small_subdataset"], From 02818977f358dd48a0ed8ae1ec84c7012e24bfcd Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:20:00 -0500 Subject: [PATCH 05/26] Shift dummydata seeding to setup, dummy path handling --- examples/ibm_rescaling/rescaling_demo.py | 2 +- torchdata/stateful_dataloader/ibm_rescalable.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index bd9885800..5e99068b8 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -48,7 +48,7 @@ data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets data = SamplingDataset( - None, + "", data, delimiter_token=-1, datasets=["sub_dataset", "second_subdataset", "small_subdataset"], diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 99328e1d7..ee8571cac 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -369,11 +369,15 @@ def __init__( self.seqlen = seqlen self.delimiter = delimiter_token # Ensure different seeds across ranks and datasets, for demo purposes - seed = seed + self.rank + len(datapath) * 100 - self.generator = torch.Generator().manual_seed(seed) + self.seed = seed + self.generator = None self.g_state = None self.state_params = ["g_state"] + def setup(self): + super().setup() + self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) + def __iter__(self): while True: out = torch.rand(self.seqlen, generator=self.generator) From a175c3c8d7745c4e843bdb9e5f85e55820e58730 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:23:53 -0500 Subject: [PATCH 06/26] Actually create dummy data folders --- examples/ibm_rescaling/rescaling_demo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 5e99068b8..9d05bb3ee 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -47,11 +47,13 @@ # Build dataloader data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets +subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] +[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] data = SamplingDataset( - "", + os.path.join(args.ckpt_path, "data"), data, delimiter_token=-1, - datasets=["sub_dataset", "second_subdataset", "small_subdataset"], + datasets=subdatas, weights=[12, 17, 5], ) # Apply rescalability layer From 957a5bf7af392d63088921e2bb4ac4e788c39a6f Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:26:04 -0500 Subject: [PATCH 07/26] Remove cfg ref --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 9d05bb3ee..241bc3848 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -64,7 +64,7 @@ data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) # If checkpoint does not exist, create it -if not os.path.exists(args.ckpt_path) or len(os.listdir(cfg.ckpt_save_path)) == 0: +if not os.path.exists(args.ckpt_path) or len(os.listdir(args.ckpt_path)) == 0: os.makedirs(args.ckpt_path, exist_ok=True) # Iterate, assemble values to exclude if rank == 0: From 2e9bdf09014e2c6a86877279e415b4d9d3a86183 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:28:01 -0500 Subject: [PATCH 08/26] Remove double () call --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ee8571cac..60c0ec331 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -388,7 +388,7 @@ def __iter__(self): def state_dict(self): # Write generator state manually self.g_state = self.generator.get_state().tolist() - return super().state_dict()() + return super().state_dict() def load_state_dict(self, state_dict): super().load_state_dict(state_dict) From e475eeca02fe3b9889147a1a9b6a8d53c07b9bcf Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:31:25 -0500 Subject: [PATCH 09/26] Fix dist checkpoint import --- torchdata/stateful_dataloader/ibm_rescalable.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 60c0ec331..05641446e 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -5,7 +5,7 @@ from typing import Any, Callable, List import torch -import torch.distributed as dist +from torch.distributed import checkpoint import torch.distributed.tensor as dtensor import torch.utils.data as data @@ -569,8 +569,8 @@ def save_distributed_state_dict( state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) # Write distributed state dict - writer = dist.checkpoint.FileSystemWriter(path) - dist.checkpoint.save( + writer = checkpoint.FileSystemWriter(path) + checkpoint.save( dstate, writer, ) @@ -606,8 +606,8 @@ def load_distributed_state_dict( # On mismatch, discard saved non-reshardable loader state and start fresh state = base # Read distributed state dict - reader = dist.checkpoint.FileSystemReader(path) - dist.checkpoint.load_state_dict( + reader = checkpoint.FileSystemReader(path) + checkpoint.load_state_dict( dstate, reader, ) From eac8ef61382ff2b701b94a39d237cb0d9571038a Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:35:48 -0500 Subject: [PATCH 10/26] Check ckp subfolder existence, not working folder --- examples/ibm_rescaling/rescaling_demo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 241bc3848..73e63d4c3 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -64,8 +64,9 @@ data = StatefulDataLoader(data, batch_size=args.b_size, num_workers=args.num_workers) # If checkpoint does not exist, create it -if not os.path.exists(args.ckpt_path) or len(os.listdir(args.ckpt_path)) == 0: - os.makedirs(args.ckpt_path, exist_ok=True) +ckpt_path = os.path.join(args.ckpt_path, "loader_dcp_state") +if not os.path.exists(ckpt_path) or len(os.listdir(ckpt_path)) == 0: + os.makedirs(ckpt_path, exist_ok=True) # Iterate, assemble values to exclude if rank == 0: print("No existing checkpoint. Processing 100 steps.") @@ -75,7 +76,7 @@ if i == 100: if rank == 0: print("Iteration complete!") - save_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + save_distributed_state_dict(data, ckpt_path, mesh) break avoid.append(inp) avoid = torch.cat(avoid) @@ -87,7 +88,7 @@ ).full_tensor() # Continue, assemble values to include - load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + load_distributed_state_dict(data, ckpt_path, mesh) if rank == 0: print("DCP state loaded!") @@ -114,7 +115,7 @@ else: if rank == 0: print("Checkpoint detected!") - load_distributed_state_dict(data, os.path.join(args.ckpt_path, "loader_dcp_state"), mesh) + load_distributed_state_dict(data, ckpt_path, mesh) vals = [] for i, inp in enumerate(data): From afd01699c906d366e5902c2a8cae7f515eeedae2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:43:15 -0500 Subject: [PATCH 11/26] Save vals for checking --- examples/ibm_rescaling/rescaling_demo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 73e63d4c3..1589d5325 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -130,6 +130,7 @@ if rank == 0: avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) include = torch.load(os.path.join(args.ckpt_path, "include.pth")) + torch.save(vals, os.path.join(args.ckpt_path, "vals.pth")) def _in(v, m): # Returns whether vector v is a row of matrix m (both tensors) From 031d67cb4e3fe568419e4b50f1684d986f8d175d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:50:54 -0500 Subject: [PATCH 12/26] Load dummy gen state always --- torchdata/stateful_dataloader/ibm_rescalable.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 05641446e..fd62469c6 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -392,9 +392,8 @@ def state_dict(self): def load_state_dict(self, state_dict): super().load_state_dict(state_dict) - # Manually set generator state if it exists - if self.g_state is not None: - self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) + # Manually set generator state + self.generator.set_state(torch.tensor(self.g_state, dtype=torch.uint8)) class ScalableShardDataset(_WrapperDataset): From d9a575bac58080eb8be2db28052734a19dd1b4d4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 20:58:51 -0500 Subject: [PATCH 13/26] Setup calls in dummy --- torchdata/stateful_dataloader/ibm_rescalable.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index fd62469c6..95ecb4c09 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -379,6 +379,7 @@ def setup(self): self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) def __iter__(self): + self.setup() while True: out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() @@ -386,6 +387,7 @@ def __iter__(self): yield out def state_dict(self): + self.setup() # Write generator state manually self.g_state = self.generator.get_state().tolist() return super().state_dict() From 157f90b2c43631f5abe9bd9f498666d959512078 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:19:40 -0500 Subject: [PATCH 14/26] Diag print --- torchdata/stateful_dataloader/ibm_rescalable.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 95ecb4c09..c3695f191 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -48,16 +48,6 @@ def _shard_partition(itemlist: List[Any], rank: int, worldsize: int) -> List[Any return itemlist[(rank * len(itemlist)) // worldsize : ((rank + 1) * len(itemlist)) // worldsize] -def _shard_inclusive(itemlist: List[Any], rank: int, worldsize: int) -> List[Any]: - """ - In cases where len(itemlist) % worldsize != 0, allow for fractional ownership of items, - and return the span including all owned items, fractional or otherwise. - """ - start = math.floor(len(itemlist) * rank / worldsize) - end = math.ceil(len(itemlist) * (rank + 1) / worldsize) - return itemlist[start:end] - - class _StatefulDataset(data.IterableDataset): """ Stub for stateful datasets, extends data.IterableDataset with state_dict methods. @@ -384,6 +374,8 @@ def __iter__(self): out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() out[-1] = self.delimiter + if self.rank==0: + print(out) yield out def state_dict(self): From 91f1b148211a691a5bcea98242d542626a32c572 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:23:14 -0500 Subject: [PATCH 15/26] Remove sampling --- examples/ibm_rescaling/rescaling_demo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 1589d5325..5af0f51c9 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -49,13 +49,13 @@ # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] -data = SamplingDataset( - os.path.join(args.ckpt_path, "data"), - data, - delimiter_token=-1, - datasets=subdatas, - weights=[12, 17, 5], -) +# data = SamplingDataset( +# os.path.join(args.ckpt_path, "data"), +# data, +# delimiter_token=-1, +# datasets=subdatas, +# weights=[12, 17, 5], +# ) # Apply rescalability layer data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors From b3569e34f9337f365ae47bbea0a8c451541af8b1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:28:48 -0500 Subject: [PATCH 16/26] Path in dummy build --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 5af0f51c9..b11e9fd43 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,7 +45,7 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset(None, rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset("data", rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] From 0faea8c27a7d704f948aada11d9dcf02f9bd78c4 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:31:20 -0500 Subject: [PATCH 17/26] Path in dummy build --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index b11e9fd43..43aa0099f 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -45,7 +45,7 @@ placement = [dist.tensor.placement_types.Shard(0)] # Build dataloader -data = DummyDataset("data", rank, world_size, delimiter_token=-1, seed=args.seed) +data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] [os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] From 0be44e40021831e99e669f4ddb5a2229f800b9ca Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:34:14 -0500 Subject: [PATCH 18/26] Scalable off --- examples/ibm_rescaling/rescaling_demo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index 43aa0099f..ed98fa518 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -57,7 +57,7 @@ # weights=[12, 17, 5], # ) # Apply rescalability layer -data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +# data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader From c54aed2eba3fc5bdaa1ec914d7724e8cada57660 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:37:28 -0500 Subject: [PATCH 19/26] Build data folder early --- examples/ibm_rescaling/rescaling_demo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index ed98fa518..a29f86fba 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -43,12 +43,12 @@ dist.init_process_group() mesh = dist.device_mesh.init_device_mesh("cpu", [world_size]) placement = [dist.tensor.placement_types.Shard(0)] +subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] +[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] # Build dataloader data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets -subdatas = ["sub_dataset", "second_subdataset", "small_subdataset"] -[os.makedirs(os.path.join(args.ckpt_path, "data", subdata), exist_ok=True) for subdata in subdatas] # data = SamplingDataset( # os.path.join(args.ckpt_path, "data"), # data, From a16ffb17042b6e472cf0ba1844bb7f03a5b6102d Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:46:14 -0500 Subject: [PATCH 20/26] Avoid resetting gen each state dict call --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index c3695f191..ea53c92dc 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -366,7 +366,8 @@ def __init__( def setup(self): super().setup() - self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) + if self.generator is None: + self.generator = torch.Generator().manual_seed(self.seed + self.rank + len(self.datapath) * 100) def __iter__(self): self.setup() From b645aeaa0e4dc80130d846c065da73565b0165f2 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:48:42 -0500 Subject: [PATCH 21/26] Diag print off, all datasets on --- examples/ibm_rescaling/rescaling_demo.py | 16 ++++++++-------- torchdata/stateful_dataloader/ibm_rescalable.py | 2 -- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index a29f86fba..c8b15aaac 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -49,15 +49,15 @@ # Build dataloader data = DummyDataset(os.path.join(args.ckpt_path, "data"), rank, world_size, delimiter_token=-1, seed=args.seed) # Pretend that we're sampling over multiple sub-datasets -# data = SamplingDataset( -# os.path.join(args.ckpt_path, "data"), -# data, -# delimiter_token=-1, -# datasets=subdatas, -# weights=[12, 17, 5], -# ) +data = SamplingDataset( + os.path.join(args.ckpt_path, "data"), + data, + delimiter_token=-1, + datasets=subdatas, + weights=[12, 17, 5], +) # Apply rescalability layer -# data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) +data = ScalableShardDataset(data, n_logical_shards=args.logical_shards) # Statelessly convert all outputs to tensors data = PreprocessDataset(data, torch.tensor) # Wrap in StatefulDataLoader diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index ea53c92dc..cbe5dd17c 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -375,8 +375,6 @@ def __iter__(self): out = torch.rand(self.seqlen, generator=self.generator) out = out.mul(self.vocab).int().tolist() out[-1] = self.delimiter - if self.rank==0: - print(out) yield out def state_dict(self): From ceffd247f48459a33bb104da6724dd8ea6500652 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Fri, 22 Nov 2024 21:57:27 -0500 Subject: [PATCH 22/26] Stop saving vals --- examples/ibm_rescaling/rescaling_demo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/ibm_rescaling/rescaling_demo.py b/examples/ibm_rescaling/rescaling_demo.py index c8b15aaac..2bb4a6bbb 100644 --- a/examples/ibm_rescaling/rescaling_demo.py +++ b/examples/ibm_rescaling/rescaling_demo.py @@ -130,7 +130,6 @@ if rank == 0: avoid = torch.load(os.path.join(args.ckpt_path, "avoid.pth")) include = torch.load(os.path.join(args.ckpt_path, "include.pth")) - torch.save(vals, os.path.join(args.ckpt_path, "vals.pth")) def _in(v, m): # Returns whether vector v is a row of matrix m (both tensors) From d2eb12ef48b6240eddcf184bf6c83859fd7403ed Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:18:56 -0500 Subject: [PATCH 23/26] Attempt single blob save --- torchdata/stateful_dataloader/ibm_rescalable.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index cbe5dd17c..23338ba81 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -560,14 +560,15 @@ def save_distributed_state_dict( rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) + out = {"state":state, "dstate":dstate} # Write distributed state dict writer = checkpoint.FileSystemWriter(path) checkpoint.save( - dstate, + out, writer, ) - # Write nondistributed state dict - torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) + # # Write nondistributed state dict + # torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) def load_distributed_state_dict( From ada91ec02647a85575b91390d470f4e8fab64a13 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:28:14 -0500 Subject: [PATCH 24/26] Attempt single blob load --- .../stateful_dataloader/ibm_rescalable.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 23338ba81..060395904 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -587,23 +587,24 @@ def load_distributed_state_dict( nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] rank = loader.dataset.rank dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) + inp = {"state":base, "dstate":dstate} # Read nondistributed state dict - ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x]) + ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) + # Read distributed state dict + reader = checkpoint.FileSystemReader(path) + checkpoint.load_state_dict( + inp, + reader, + ) + dstate = inp["dstate"] # Check that number of loaders matches if ckp_ws == loader.dataset.worldsize: - state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth")) # Check that number of workers matches if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: - state = base + state = inp["state"] else: # On mismatch, discard saved non-reshardable loader state and start fresh state = base - # Read distributed state dict - reader = checkpoint.FileSystemReader(path) - checkpoint.load_state_dict( - dstate, - reader, - ) # Get local tensors from dtensors, and slice over workers dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()} # Flip dict[list[tensor]] to list[dict[tensor]] From 9bf8f3d61927e42158955160ce2b818ff383cbb6 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:37:56 -0500 Subject: [PATCH 25/26] Prevent loading in place --- torchdata/stateful_dataloader/ibm_rescalable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 060395904..4d789541f 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -585,9 +585,8 @@ def load_distributed_state_dict( """ base = loader.state_dict() nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] - rank = loader.dataset.rank dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) - inp = {"state":base, "dstate":dstate} + inp = {"state":deepcopy(base), "dstate":dstate} # Read nondistributed state dict ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) # Read distributed state dict From 934d37b5995e025004dc9692fa4ecc8c45482e41 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:55:39 -0500 Subject: [PATCH 26/26] Cleanup --- torchdata/stateful_dataloader/ibm_rescalable.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchdata/stateful_dataloader/ibm_rescalable.py b/torchdata/stateful_dataloader/ibm_rescalable.py index 4d789541f..f08f9c6ef 100644 --- a/torchdata/stateful_dataloader/ibm_rescalable.py +++ b/torchdata/stateful_dataloader/ibm_rescalable.py @@ -557,7 +557,6 @@ def save_distributed_state_dict( Rescalable worker states are compiled into a dtensor across ranks, and saved using pytorch distributed checkpointing. """ - rank = loader.dataset.rank state = deepcopy(loader.state_dict()) dstate = __pop_dstate(state, device_mesh, [dtensor.placement_types.Shard(0)]) out = {"state":state, "dstate":dstate} @@ -567,8 +566,6 @@ def save_distributed_state_dict( out, writer, ) - # # Write nondistributed state dict - # torch.save(state, os.path.join(path, f"__nondist_cp_{rank}.pth")) def load_distributed_state_dict( @@ -587,8 +584,6 @@ def load_distributed_state_dict( nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"] dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements) inp = {"state":deepcopy(base), "dstate":dstate} - # Read nondistributed state dict - ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) # Read distributed state dict reader = checkpoint.FileSystemReader(path) checkpoint.load_state_dict( @@ -597,6 +592,7 @@ def load_distributed_state_dict( ) dstate = inp["dstate"] # Check that number of loaders matches + ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path)) if ckp_ws == loader.dataset.worldsize: # Check that number of workers matches if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]: