Update torchdata.nodes docs, use sphinx for API #1396

Merged
6 commits, merged on Dec 11, 2024
57 changes: 57 additions & 0 deletions docs/source/getting_started_with_torchdata_nodes.rst
@@ -0,0 +1,57 @@
Getting Started With ``torchdata.nodes`` (beta)
===============================================

Install torchdata with pip.

.. code:: bash

    pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) in an ``IterableWrapper`` to convert it to a ``BaseNode`` and get started:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    loader = Loader(node)
    result = list(loader)
    print(result)
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
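
Nodes compose by wrapping one another. As a small illustrative sketch building on the
pipeline above (the grouping shown assumes the default in-order behaviour of
``ParallelMapper``), batches can be formed with ``Batcher``:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Batcher, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    node = Batcher(node, batch_size=4, drop_last=False)  # group mapped items into lists of up to 4
    loader = Loader(node)
    print(list(loader))
    # [[0, 1, 4, 9], [16, 25, 36, 49], [64, 81]]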

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can re-use your existing
``torch.utils.data.Dataset`` classes. See :ref:`migrate-to-nodes-from-utils` for an in-depth
example.

.. code:: python

    import torch.utils.data
    from torch.utils.data import RandomSampler
    from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


    class SquaredDataset(torch.utils.data.Dataset):
        def __getitem__(self, i: int) -> int:
            return i**2

        def __len__(self):
            return 10


    dataset = SquaredDataset()
    sampler = RandomSampler(dataset)

    # For fine-grained control of iteration order, define your own sampler
    node = SamplerWrapper(sampler)
    # Apply the dataset's __getitem__ as a map function to the indices generated by the sampler
    node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
    # Loader converts a node (an iterator) into an Iterable that may be re-used for multiple epochs
    loader = Loader(node)
    print(list(loader))
    # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
    print(list(loader))
    # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
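
To produce collated tensor batches from the same pipeline, you can batch the sampler
indices and collate inside the map function. A small sketch building on the example
above (the printed batches are illustrative, since the sampler is random; see
:ref:`migrate-to-nodes-from-utils` for the full pattern):

.. code:: python

    from torch.utils.data import default_collate
    from torchdata.nodes import Batcher

    node = SamplerWrapper(RandomSampler(dataset))
    # Batch the indices first, then fetch and collate each batch in worker threads
    node = Batcher(node, batch_size=4, drop_last=False)
    node = ParallelMapper(
        node,
        map_fn=lambda idxs: default_collate([dataset[i] for i in idxs]),
        num_workers=3,
        method="thread",
    )
    loader = Loader(node)
    for batch in loader:
        print(batch)
    # e.g. tensor([49,  0, 25, 81]), tensor([16, 36,  1,  9]), tensor([64,  4])
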
9 changes: 8 additions & 1 deletion docs/source/index.rst
@@ -31,19 +31,26 @@ Features described in this documentation are classified by release status:
binary distributions like PyPI or Conda, except sometimes behind run-time
flags, and are at an early stage for feedback and testing.

.. toctree::
   :maxdepth: 2
   :caption: Developer Notes:

   what_is_torchdata_nodes.rst

.. toctree::
   :maxdepth: 2
   :caption: API Reference:

   torchdata.nodes.rst
   torchdata.stateful_dataloader.rst


.. toctree::
   :maxdepth: 2
   :caption: Tutorial and Examples:

   getting_started_with_torchdata_nodes.rst
   migrate_to_nodes_from_utils.rst
   stateful_dataloader_tutorial.rst


184 changes: 184 additions & 0 deletions docs/source/migrate_to_nodes_from_utils.rst
@@ -0,0 +1,184 @@
.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help users familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`, get started with
``torchdata.nodes``, and to provide a starting point for defining your own dataloading pipelines.

We'll demonstrate how to achieve the most common ``DataLoader`` features, re-use existing samplers and datasets,
and load/save dataloader state. ``torchdata.nodes`` performs at least as well as ``DataLoader`` and ``StatefulDataLoader``;
see :ref:`how-does-nodes-perform`.

Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor args and go from there.

.. code:: python

    class DataLoader:
        def __init__(
            self,
            dataset: Dataset[_T_co],
            batch_size: Optional[int] = 1,
            shuffle: Optional[bool] = None,
            sampler: Union[Sampler, Iterable, None] = None,
            batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
            num_workers: int = 0,
            collate_fn: Optional[_collate_fn_t] = None,
            pin_memory: bool = False,
            drop_last: bool = False,
            timeout: float = 0,
            worker_init_fn: Optional[_worker_init_fn_t] = None,
            multiprocessing_context=None,
            generator=None,
            *,
            prefetch_factor: Optional[int] = None,
            persistent_workers: bool = False,
            pin_memory_device: str = "",
            in_order: bool = True,
        ):
            ...

As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creating batches of ``batch_size`` indices.
If no sampler is provided, a ``RandomSampler`` or ``SequentialSampler`` is created by default.
The indices are passed to ``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch
of samples. If ``num_workers > 0``, ``DataLoader`` spawns worker subprocesses and passes the batches of
indices to them; each worker calls ``Dataset.__getitem__()`` and applies ``collate_fn``
before returning the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch.
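
To make that flow concrete, here is a simplified, single-process sketch of roughly what happens
(ignoring workers, ``pin_memory``, and most other options); the helper name ``simple_dataloader``
is purely illustrative and not part of any API:

.. code:: python

    from torch.utils.data import RandomSampler, SequentialSampler, default_collate

    def simple_dataloader(dataset, batch_size, shuffle=False, collate_fn=default_collate, drop_last=False):
        # 1) Generate indices from a sampler
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        batch = []
        for idx in sampler:
            batch.append(idx)
            # 2) Group indices into batches of batch_size
            if len(batch) == batch_size:
                # 3) Fetch samples with __getitem__ and collate them
                yield collate_fn([dataset[i] for i in batch])
                batch = []
        if batch and not drop_last:
            yield collate_fn([dataset[i] for i in batch])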

Now let's look at what an equivalent implementation of ``DataLoader`` might look like, built with ``torchdata.nodes``.

.. code:: python

    from typing import List, Callable
    import torchdata.nodes as tn
    from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset

    class MapAndCollate:
        """A simple transform that takes a batch of indices, maps with dataset, and then applies
        collate.
        TODO: make this a standard utility in torchdata.nodes
        """
        def __init__(self, dataset, collate_fn):
            self.dataset = dataset
            self.collate_fn = collate_fn

        def __call__(self, batch_of_indices: List[int]):
            batch = [self.dataset[i] for i in batch_of_indices]
            return self.collate_fn(batch)

    # To keep things simple, let's assume that the following args are provided by the caller
    def NodesDataLoader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        collate_fn: Callable | None,
        pin_memory: bool,
        drop_last: bool,
    ):
        # Assume we're working with a map-style dataset
        assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
        # Start with a sampler, since the caller did not provide one
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        # SamplerWrapper converts a Sampler to a BaseNode
        node = tn.SamplerWrapper(sampler)

        # Now let's batch sampler indices together
        node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

        # Create a map function that accepts a list of indices, applies __getitem__ to each,
        # and then collates them
        map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

        # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
        # choose process or thread workers. Note that if you're not using Free-Threaded
        # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
        # and slow down training.
        node = tn.ParallelMapper(
            node,
            map_fn=map_and_collate,
            num_workers=num_workers,
            method="process",  # Set this to "thread" for multi-threading
            in_order=True,
        )

        # Optionally apply pin-memory, and we usually do some pre-fetching
        if pin_memory:
            node = tn.PinMemory(node)
        node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

        # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
        # on it to start a new epoch.
        # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
        # also provides state_dict and load_state_dict methods.
        return tn.Loader(node)

Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

    class SquaredDataset(Dataset):
        def __init__(self, len: int):
            self.len = len

        def __len__(self):
            return self.len

        def __getitem__(self, i: int) -> int:
            return i**2

    loader = NodesDataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    batches = []
    for idx, batch in enumerate(loader):
        if idx == 2:
            state_dict = loader.state_dict()
            # Saves the state_dict after batch 2 has been returned
        batches.append(batch)

    loader.load_state_dict(state_dict)
    batches_after_loading = list(loader)
    print(batches[3:])
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches_after_loading)
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
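
In a real training loop you would typically persist this state alongside model and optimizer
state. A rough sketch (the checkpoint path is hypothetical; the loader state may contain plain
Python objects, hence ``weights_only=False`` on load):

.. code:: python

    import torch

    # Save the loader's progress together with the rest of a checkpoint
    torch.save({"loader": loader.state_dict()}, "checkpoint.pt")
    # ... later, e.g. after a restart ...
    checkpoint = torch.load("checkpoint.pt", weights_only=False)
    loader.load_state_dict(checkpoint["loader"])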

Let's also compare this to ``torch.utils.data.DataLoader`` as a sanity check.

.. code:: python

    import torch.utils.data

    loaderv1 = torch.utils.data.DataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        persistent_workers=False,  # Coming soon to torchdata.nodes!
    )
    print(list(loaderv1))
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches)
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]


IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your ``IterableDataset`` into a ``tn.IterableWrapper``, some functions like
``get_worker_info`` are not yet supported. However, we believe that sharding work between multi-process workers is
often not necessary: you can keep the indexing in the main process and parallelize only the heavier transforms,
similar to how the map-style datasets work above.
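
As a rough sketch of that pattern (the dataset and transform below are purely illustrative),
keep the ``IterableDataset`` iteration in the main process and parallelize only the heavy map:

.. code:: python

    import torchdata.nodes as tn
    from torch.utils.data import IterableDataset

    class MyIterableDataset(IterableDataset):
        def __iter__(self):
            return iter(range(10))

    def heavy_transform(x):
        # Stand-in for an expensive decode / augmentation step
        return x ** 2

    node = tn.IterableWrapper(MyIterableDataset())  # iterated in the main process
    node = tn.ParallelMapper(node, map_fn=heavy_transform, num_workers=4, method="thread")
    loader = tn.Loader(node)
    print(list(loader))
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]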