From 6edf0402d3eaf8d2805cc586f81a8e5af8e83b3b Mon Sep 17 00:00:00 2001
From: Divyansh Khanna
Date: Mon, 9 Dec 2024 20:14:31 -0800
Subject: [PATCH] add new torchdata.nodes doc file (#1390)

* add new torchdata.nodes doc file

* update main readme.md, move content of nodes readme.md to torchdata.nodes.rst

* main readme

* update readme
---
 README.md                       |   7 ++
 docs/source/index.rst           |   1 +
 docs/source/torchdata.nodes.rst | 202 ++++++++++++++++++++++++++++++++
 3 files changed, 210 insertions(+)
 create mode 100644 docs/source/torchdata.nodes.rst

diff --git a/README.md b/README.md
index 8355fd5e1..21c3eabb1 100644
--- a/README.md
+++ b/README.md
@@ -35,6 +35,13 @@ provides state_dict and load_state_dict functionality.
 See examples [in this Colab notebook](https://colab.research.google.com/drive/1tonoovEd7Tsi8EW8ZHXf0v3yHJGwZP8M?usp=sharing).
 
+## torchdata.nodes
+
+torchdata.nodes is a library of composable iterators (not iterables!) that let you chain together common dataloading and
+pre-proc operations. It follows a streaming programming model, although "sampler + Map-style" can still be configured if
+you desire. See the [torchdata.nodes main page](torchdata/nodes) for more details. Stay tuned for a tutorial on
+torchdata.nodes, coming soon!
+
 ## Installation
 
 ### Version Compatibility

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 18a8d21fe..46b92cabc 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -37,6 +37,7 @@ Features described in this documentation are classified by release status:
    :caption: API Reference:
 
    torchdata.stateful_dataloader.rst
+   torchdata.nodes.rst
 
 .. toctree::

diff --git a/docs/source/torchdata.nodes.rst b/docs/source/torchdata.nodes.rst
new file mode 100644
index 000000000..d5197c0e0
--- /dev/null
+++ b/docs/source/torchdata.nodes.rst
@@ -0,0 +1,202 @@
torchdata.nodes
===============

What is ``torchdata.nodes``?
----------------------------

``torchdata.nodes`` is a library of composable iterators (not
iterables!) that let you chain together common dataloading and pre-proc
operations. It follows a streaming programming model, although "sampler
+ Map-style" can still be configured if you desire.

``torchdata.nodes`` adds more flexibility to the standard
``torch.utils.data`` offering, and introduces multi-threaded parallelism
in addition to multi-process (the only supported approach in
``torch.utils.data.DataLoader``), as well as first-class support for
mid-epoch checkpointing through a ``state_dict/load_state_dict``
interface.

``torchdata.nodes`` strives to include as many useful operators as
possible; however, it is designed to be extensible. New nodes must
subclass ``torchdata.nodes.BaseNode`` (which itself subclasses
``typing.Iterator``) and implement the ``next()``, ``reset(initial_state)``,
and ``get_state()`` operations (notably, not ``__next__``,
``load_state_dict``, nor ``state_dict``).

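For illustration, a minimal custom node might look like the sketch below.
This is only a sketch, not library code: the ``EvenNumbers`` class and its
``limit``/``_step`` attributes are made up for this snippet, and it assumes
the ``BaseNode`` contract described above (``reset`` receives an optional
state dict and calls ``super().reset(initial_state)``).

.. code:: python

    from typing import Any, Dict, Optional

    from torchdata.nodes import BaseNode


    class EvenNumbers(BaseNode):
        """Sketch of a custom node: yields 0, 2, 4, ... with explicit state."""

        def __init__(self, limit: int):
            super().__init__()
            self.limit = limit
            self._step = 0

        def reset(self, initial_state: Optional[Dict[str, Any]] = None):
            # Assumed contract: all (re-)initialization happens in reset(),
            # optionally from a previously captured state dict.
            super().reset(initial_state)
            self._step = initial_state["step"] if initial_state is not None else 0

        def next(self) -> int:
            if self._step >= self.limit:
                raise StopIteration()
            value = 2 * self._step
            self._step += 1
            return value

        def get_state(self) -> Dict[str, Any]:
            return {"step": self._step}

A node like this composes with the built-in operators in the same way as the
wrappers shown below, e.g. ``list(Loader(EvenNumbers(limit=5)))``.
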
Getting started
---------------

Install torchdata with pip.

.. code:: bash

    pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a BaseNode and get
started:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    loader = Loader(node)
    result = list(loader)
    print(result)
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s.

.. code:: python

    import torch.utils.data
    from torch.utils.data import RandomSampler
    from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


    class SquaredDataset(torch.utils.data.Dataset):
        def __getitem__(self, i: int) -> int:
            return i**2

        def __len__(self):
            return 10


    dataset = SquaredDataset()
    sampler = RandomSampler(dataset)

    # For fine-grained control of iteration order, define your own sampler
    node = SamplerWrapper(sampler)
    # Simply apply the dataset's __getitem__ as a map function to the indices generated by the sampler
    node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
    # Loader converts a node (iterator) into an Iterable that may be reused for multiple epochs
    loader = Loader(node)
    print(list(loader))
    # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
    print(list(loader))
    # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]

What's the point of ``torchdata.nodes``?
-----------------------------------------

We get it, ``torch.utils.data`` just works for many, many use cases.
However, it definitely has a bunch of rough spots:

Multiprocessing sucks
~~~~~~~~~~~~~~~~~~~~~

- You need to duplicate memory stored in your Dataset (because of
  Python copy-on-read)
- IPC over multiprocess queues is slow and can introduce slow startup
  times
- You're forced to perform batching in the workers instead of the main
  process to reduce IPC overhead, which increases peak memory.
- With GIL-releasing functions and Free-Threaded Python,
  multi-threading may not be GIL-bound like it used to be.

``torchdata.nodes`` enables both multi-threading and multi-processing, so
you can choose what works best for your particular setup. Parallelism
is primarily configured in Mapper operators, giving you flexibility in
what, when, and how you parallelize.

Map-style and random-access doesn't scale
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current map-style dataset approach is great for datasets that fit in
memory, but true random access is not going to be very performant once
your dataset grows beyond memory limitations, unless you jump through
some hoops with a special sampler.

``torchdata.nodes`` follows a streaming data model, where operators are
Iterators that can be combined together to define a dataloading and
pre-proc pipeline. Samplers are still supported (see the example above)
and can be combined with a Mapper to produce an Iterator.

Multi-Datasets do not fit well with the current implementation in ``torch.utils.data``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current Sampler concept (one per dataloader) starts to break down
when you try to combine multiple datasets. (For single datasets,
samplers are a great abstraction and will continue to be supported!)

- For multi-datasets, consider this scenario: ``len(dsA): 10`` and
  ``len(dsB): 20``. Now we want to do round-robin (or sample uniformly)
  between these two datasets to feed to our trainer. With just a single
  sampler, how can you implement that strategy? Maybe a sampler that
  emits tuples? What if you want to swap in a ``RandomSampler`` or a
  ``DistributedSampler``? How will ``sampler.set_epoch`` work?

``torchdata.nodes`` helps to address and scale multi-dataset dataloading
by dealing only with Iterators, thereby forcing samplers and datasets
together, and focusing on composing smaller primitive nodes into a more
complex dataloading pipeline.

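As a sketch of what such composition might look like, the snippet below
combines two wrapped streams with ``MultiNodeWeightedSampler``. Treat the
argument names used here (``source_nodes``, ``weights``) and the exact
interleaving behavior as assumptions to verify against the API reference,
not as a definitive recipe.

.. code:: python

    from torchdata.nodes import IterableWrapper, Loader, MultiNodeWeightedSampler

    # Two source "datasets" of different sizes, each wrapped as a node.
    ds_a = IterableWrapper(range(10))
    ds_b = IterableWrapper(range(100, 120))

    # Sample between the two streams with the given probabilities.
    # NOTE: the source_nodes/weights argument names are assumptions in this
    # sketch; check the MultiNodeWeightedSampler reference for the signature.
    node = MultiNodeWeightedSampler(
        source_nodes={"a": ds_a, "b": ds_b},
        weights={"a": 0.5, "b": 0.5},
    )

    loader = Loader(node)
    for item in loader:
        ...  # items from ds_a and ds_b, interleaved according to the weights
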
IterableDataset + multiprocessing requires additional dataset sharding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Dataset sharding is required for data-parallel training, which is fairly
reasonable. But what about sharding between dataloader workers? With
Map-style datasets, distribution of work between workers is handled by
the main process, which distributes sampler indices to workers. With
IterableDatasets, each worker needs to figure out (through
``torch.utils.data.get_worker_info``) what data it should be returning.

Design choices
--------------

No Generator BaseNodes
~~~~~~~~~~~~~~~~~~~~~~

See https://github.com/pytorch/data/pull/1362 for more thoughts.

One difficult choice we made was to disallow Generators when defining a
new BaseNode implementation. We dropped Generator support and moved to an
Iterator-only foundation for a few reasons around state management:

1. We require explicit state handling in BaseNode implementations.
   Generators store state implicitly on the stack, and we found that we
   needed to jump through hoops and write very convoluted code to get
   basic state working with Generators.
2. End-of-iteration state dict: Iterables may feel more natural; however,
   a bunch of issues come up around state management. Consider the
   end-of-iteration state dict. If you load this state_dict into your
   Iterable, should it represent the end of the current iteration or the
   start of the next one?
3. Loading state: If you call load_state_dict() on an Iterable, most
   users would expect the next iterator requested from it to start with
   the loaded state. But what if ``iter`` is called twice before
   iteration begins?
4. Multiple Live Iterator problem: if you have one instance of an
   Iterable but two live iterators, what does it mean to call
   state_dict() on the Iterable? In dataloading this is very rare, but
   we would still need to work around it and make a bunch of
   assumptions. Forcing developers who implement BaseNodes to reason
   about these scenarios is, in our opinion, worse than disallowing
   generators and Iterables.

``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators
define ``next()``, ``get_state()``, and ``reset(initial_state | None)``.
All re-initialization should be done in ``reset()``, including
initializing with a particular state if one is passed.

However, end-users are used to dealing with Iterables, for example,

::

    for epoch in range(5):
        # Most frameworks and users don't expect to call loader.reset()
        for batch in loader:
            ...
        sd = loader.state_dict()
        # Loading sd should not throw StopIteration right away, but instead start at the next epoch

To handle this, we keep all of the assumptions and special end-of-epoch
handling in a single ``Loader`` class, which takes any BaseNode and makes
it an Iterable, handling the ``reset()`` calls and end-of-epoch
``state_dict`` loading.

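To make this concrete, here is a small sketch of the intended usage pattern,
reusing the generator example from earlier. The commented values are
assumptions about the resume semantics rather than guaranteed output; the
key point is that ``state_dict()``/``load_state_dict()`` on the ``Loader``
capture and restore progress, including mid-epoch.

.. code:: python

    from torchdata.nodes import IterableWrapper, Loader

    loader = Loader(IterableWrapper(range(10)))

    # Consume half an epoch, then checkpoint mid-epoch.
    it = iter(loader)
    first_half = [next(it) for _ in range(5)]  # [0, 1, 2, 3, 4]
    checkpoint = loader.state_dict()

    # Later (e.g. after a restart): rebuild the same pipeline and restore state.
    resumed = Loader(IterableWrapper(range(10)))
    resumed.load_state_dict(checkpoint)
    second_half = list(resumed)  # expected to continue with [5, ..., 9]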