Cherry picking PR#1390 onto release branch #1392

Merged: 1 commit merged into the release branch on Dec 10, 2024.

7 changes: 7 additions & 0 deletions README.md
@@ -35,6 +35,13 @@ provides state_dict and load_state_dict functionality. See
examples
[in this Colab notebook](https://colab.research.google.com/drive/1tonoovEd7Tsi8EW8ZHXf0v3yHJGwZP8M?usp=sharing).

## torchdata.nodes

torchdata.nodes is a library of composable iterators (not iterables!) that let you chain together common dataloading and
pre-proc operations. It follows a streaming programming model, although "sampler + Map-style" can still be configured if
you desire. See the [torchdata.nodes main page](torchdata/nodes) for more details. Stay tuned for a tutorial on
torchdata.nodes, coming soon!

## Installation

### Version Compatibility
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -37,6 +37,7 @@ Features described in this documentation are classified by release status:
:caption: API Reference:

torchdata.stateful_dataloader.rst
torchdata.nodes.rst


.. toctree::
202 changes: 202 additions & 0 deletions docs/source/torchdata.nodes.rst
@@ -0,0 +1,202 @@
torchdata.nodes
===============

What is ``torchdata.nodes``?
----------------------------

``torchdata.nodes`` is a library of composable iterators (not
iterables!) that let you chain together common dataloading and pre-proc
operations. It follows a streaming programming model, although “sampler
+ Map-style” can still be configured if you desire.

``torchdata.nodes`` adds more flexibility to the standard
``torch.utils.data`` offering: it introduces multi-threaded parallelism
in addition to multi-processing (the only approach supported by
``torch.utils.data.DataLoader``), as well as first-class support for
mid-epoch checkpointing through a ``state_dict``/``load_state_dict``
interface.
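
For example, here is a brief sketch of what mid-epoch checkpointing can look
like (assuming ``Loader`` exposes the ``state_dict``/``load_state_dict`` pair
described above; see the API reference for exact semantics):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader

   loader = Loader(IterableWrapper(range(10)))
   it = iter(loader)
   print(next(it), next(it))  # consume a couple of items: 0 1

   # Capture the pipeline state mid-epoch and restore it into a fresh pipeline
   sd = loader.state_dict()
   resumed = Loader(IterableWrapper(range(10)))
   resumed.load_state_dict(sd)
   print(list(resumed))  # expected to resume from where the checkpoint was taken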

``torchdata.nodes`` strives to include as many useful operators as
possible; however, it is designed to be extensible. New nodes must
subclass ``torchdata.nodes.BaseNode`` (which itself subclasses
``typing.Iterator``) and implement the ``next()``, ``reset(initial_state)``,
and ``get_state()`` operations (notably, not ``__next__``,
``load_state_dict``, nor ``state_dict``).
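
As a sketch of what this interface looks like in practice, here is a
hypothetical ``RangeNode`` (the name and state layout are illustrative, not
part of the library; exact base-class requirements, such as whether
``super().reset()`` must be called, may differ):

.. code:: python

   from typing import Any, Dict, Optional

   from torchdata.nodes import BaseNode


   class RangeNode(BaseNode):
       """Hypothetical node yielding 0..n-1, with its position as explicit state."""

       def __init__(self, n: int):
           super().__init__()
           self.n = n
           self.i = 0

       def reset(self, initial_state: Optional[Dict[str, Any]] = None):
           # Assumed: BaseNode.reset performs base-class bookkeeping
           super().reset(initial_state)
           self.i = initial_state["i"] if initial_state is not None else 0

       def next(self) -> int:
           # Signal end-of-iteration just like __next__ would
           if self.i >= self.n:
               raise StopIteration()
           value, self.i = self.i, self.i + 1
           return value

       def get_state(self) -> Dict[str, Any]:
           return {"i": self.i}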

Getting started
---------------

Install torchdata with pip.

.. code:: bash

   pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a BaseNode and get
started:

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   loader = Loader(node)
   result = list(loader)
   print(result)
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s:

.. code:: python

   import torch.utils.data
   from torch.utils.data import RandomSampler
   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


   class SquaredDataset(torch.utils.data.Dataset):
       def __getitem__(self, i: int) -> int:
           return i**2

       def __len__(self):
           return 10

   dataset = SquaredDataset()
   sampler = RandomSampler(dataset)

   # For fine-grained control of iteration order, define your own sampler
   node = SamplerWrapper(sampler)
   # Simply apply the dataset's __getitem__ as a map function to the indices generated by the sampler
   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
   # Loader converts a node (Iterator) into an Iterable that can be reused for multiple epochs
   loader = Loader(node)
   print(list(loader))
   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
   print(list(loader))
   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]

What’s the point of ``torchdata.nodes``?
----------------------------------------

We get it, ``torch.utils.data`` just works for many, many use cases.
However, it has a number of rough spots:

Multiprocessing sucks
~~~~~~~~~~~~~~~~~~~~~

- You need to duplicate memory stored in your Dataset (because of
Python copy-on-read)
- IPC is slow over multiprocess queues and can introduce slow startup
times
- You’re forced to perform batching on the workers instead of the
  main process to reduce IPC overhead, increasing peak memory.
- With GIL-releasing functions and Free-Threaded Python,
multi-threading may not be GIL-bound like it used to be.

``torchdata.nodes`` enables both multi-threading and multi-processing, so
you can choose what works best for your particular setup. Parallelism is
primarily configured in Mapper operators, giving you flexibility in what,
when, and how to parallelize.
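
For instance, switching between the two is a one-argument change (a sketch;
the ``map_fn`` must be picklable for the process-based variant):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader, ParallelMapper

   def add_one(x: int) -> int:
       # Module-level function so it can be pickled when method="process"
       return x + 1

   if __name__ == "__main__":
       node = IterableWrapper(range(10))
       # method="thread" suits GIL-releasing work (I/O, decoding in C extensions);
       # method="process" suits CPU-bound Python work that holds the GIL.
       node = ParallelMapper(node, map_fn=add_one, num_workers=4, method="process")
       print(list(Loader(node)))  # 1..10 (assuming default in-order behaviour)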

Map-style and random-access doesn’t scale
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current map-style dataset approach is great for datasets that fit in
memory, but true random access will not be very performant once your
dataset grows beyond memory limits, unless you jump through hoops with a
specialized sampler.

``torchdata.nodes`` follows a streaming data model, where operators are
Iterators that can be combined to define a dataloading and pre-proc
pipeline. Samplers are still supported (see the example above) and can be
combined with a Mapper to produce an Iterator.

Multi-Datasets do not fit well with the current implementation in ``torch.utils.data``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The current Sampler concept (one per DataLoader) starts to break down
when you try to combine multiple datasets. (For single datasets, samplers
are a great abstraction and will continue to be supported!)

- For multi-datasets, consider this scenario: ``len(dsA): 10``,
  ``len(dsB): 20``. We want to round-robin (or sample uniformly) between
  these two datasets to feed our trainer. With just a single sampler, how
  can you implement that strategy? Maybe with a sampler that emits tuples?
  What if you want to swap in a RandomSampler or a DistributedSampler?
  How would ``sampler.set_epoch`` work?

``torchdata.nodes`` helps address and scale multi-dataset dataloading by
dealing only with Iterators, thereby forcing samplers and datasets
together and focusing on composing small primitive nodes into a more
complex dataloading pipeline, as sketched below.
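
As a rough sketch of what that composition can look like, assuming a weighted
multi-source node along the lines of ``MultiNodeWeightedSampler`` (the exact
name and signature should be checked against the API reference):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader, MultiNodeWeightedSampler

   ds_a = IterableWrapper(range(10))        # plays the role of dsA, len 10
   ds_b = IterableWrapper(range(100, 120))  # plays the role of dsB, len 20

   # Sample uniformly between the two sources; weights are assumed to be
   # per-source sampling probabilities keyed by source name.
   node = MultiNodeWeightedSampler(
       source_nodes={"a": ds_a, "b": ds_b},
       weights={"a": 0.5, "b": 0.5},
   )
   loader = Loader(node)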

IterableDataset + multiprocessing requires additional dataset sharding
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Dataset sharding is required for data-parallel training, which is fairly
reasonable. But what about sharding between dataloader workers? With
Map-style datasets, distribution of work between workers is handled by
the main process, which distributes sampler indices to workers. With
IterableDatasets, each worker needs to figure out (through
``torch.utils.data.get_worker_info``) what data it should be returning.
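
For illustration, this is the kind of per-worker sharding logic each
``IterableDataset`` has to carry today (standard ``torch.utils.data`` API;
the class itself is just a sketch):

.. code:: python

   from torch.utils.data import IterableDataset, get_worker_info

   class ShardedIterable(IterableDataset):
       """Each dataloader worker must skip the items that belong to other workers."""

       def __init__(self, n: int):
           self.n = n

       def __iter__(self):
           info = get_worker_info()
           worker_id = info.id if info is not None else 0
           num_workers = info.num_workers if info is not None else 1
           # Stride over the stream so workers return disjoint shards
           return iter(range(worker_id, self.n, num_workers))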

Design choices
--------------

No Generator BaseNodes
~~~~~~~~~~~~~~~~~~~~~~

See https://github.com/pytorch/data/pull/1362 for more thoughts.

One difficult choice we made was to disallow Generators when defining a
new BaseNode implementation, moving instead to an Iterator-only
foundation, for a few reasons around state management:

1. We require explicit state handling in BaseNode implementations.
   Generators store state implicitly on the stack, and we found that we
   needed to jump through hoops and write very convoluted code to get
   basic state working with Generators.
2. End-of-iteration state dict: Iterables may feel more natural, however
   a number of issues come up around state management. Consider the
   end-of-iteration state dict. If you load this state_dict into your
   iterable, should it represent the end of the current iteration or the
   start of the next one?
3. Loading state: if you call ``load_state_dict()`` on an Iterable, most
   users would expect the next iterator requested from it to start with
   the loaded state. However, what if ``iter()`` is called twice before
   iteration begins?
4. Multiple live iterators: if you have one instance of an Iterable but
   two live iterators, what does it mean to call ``state_dict()`` on the
   Iterable? In dataloading this is very rare, but we would still need to
   work around it and make a number of assumptions. Forcing developers who
   implement BaseNodes to reason about these scenarios is, in our opinion,
   worse than disallowing Generators and Iterables.

``torchdata.nodes.BaseNode`` implementations are Iterators. Iterators
define ``next()``, ``get_state()``, and ``reset(initial_state | None)``.
All re-initialization should be done in reset(), including initializing
with a particular state if one is passed.

However, end-users are used to dealing with Iterables, for example,

::

   for epoch in range(5):
       # Most frameworks and users don't expect to call loader.reset()
       for batch in loader:
           ...
       sd = loader.state_dict()
       # Loading sd should not throw StopIteration right away, but instead start at the next epoch

To handle this, we keep all of the assumptions and special end-of-epoch
handling in a single ``Loader`` class, which takes any BaseNode and makes
it an Iterable, handling the ``reset()`` calls and end-of-epoch
``state_dict`` loading.
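
A short sketch of the resulting end-user behaviour (assuming ``Loader``
exposes ``load_state_dict``, in line with the interface described above):

.. code:: python

   from torchdata.nodes import IterableWrapper, Loader

   loader = Loader(IterableWrapper(range(4)))
   _ = list(loader)          # run a full epoch
   sd = loader.state_dict()  # end-of-epoch state

   fresh = Loader(IterableWrapper(range(4)))
   fresh.load_state_dict(sd)
   # Per the design above, this should yield a full next epoch rather than
   # raising StopIteration immediately.
   print(list(fresh))  # [0, 1, 2, 3]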