2024-12-12 nightly release (c0edd90)
pytorchbot committed Dec 12, 2024
1 parent 19f24de commit 5c4c403
Showing 16 changed files with 503 additions and 258 deletions.
57 changes: 57 additions & 0 deletions docs/source/getting_started_with_torchdata_nodes.rst
@@ -0,0 +1,57 @@
Getting Started With ``torchdata.nodes`` (beta)
===============================================

Install torchdata with pip.

.. code:: bash

   pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a ``BaseNode`` and get started.

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   loader = Loader(node)

   result = list(loader)
   print(result)
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can keep using your existing
``torch.utils.data.Dataset`` classes. See :ref:`migrate-to-nodes-from-utils` for an in-depth
example.

.. code:: python

   import torch.utils.data
   from torch.utils.data import RandomSampler
   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


   class SquaredDataset(torch.utils.data.Dataset):
       def __getitem__(self, i: int) -> int:
           return i**2

       def __len__(self):
           return 10


   dataset = SquaredDataset()

   # For fine-grained control of iteration order, define your own sampler
   sampler = RandomSampler(dataset)
   node = SamplerWrapper(sampler)

   # Simply apply the dataset's __getitem__ as a map function to the indices generated by the sampler
   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")

   # Loader converts a node (an iterator) into an Iterable that may be reused for multiple epochs
   loader = Loader(node)

   print(list(loader))
   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
   print(list(loader))
   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
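
``Loader`` also provides ``state_dict()`` and ``load_state_dict()`` methods for checkpointing. The snippet
below is a minimal sketch of a mid-epoch snapshot using the ``loader`` defined above; see
:ref:`migrate-to-nodes-from-utils` for a complete, step-by-step example.

.. code:: python

   # Minimal sketch: snapshot the loader's state mid-epoch, then restore it later
   state = None
   for i, item in enumerate(loader):
       if i == 2:
           state = loader.state_dict()  # captured after the third item is returned

   # Later (possibly in a new process, with an identically constructed loader):
   loader.load_state_dict(state)
   remaining = list(loader)  # continues from where the snapshot was taken
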
9 changes: 8 additions & 1 deletion docs/source/index.rst
@@ -31,19 +31,26 @@ Features described in this documentation are classified by release status:
binary distributions like PyPI or Conda, except sometimes behind run-time
flags, and are at an early stage for feedback and testing.

.. toctree::
:maxdepth: 2
:caption: Developer Notes:

what_is_torchdata_nodes.rst

.. toctree::
:maxdepth: 2
:caption: API Reference:

torchdata.stateful_dataloader.rst
torchdata.nodes.rst
torchdata.stateful_dataloader.rst


.. toctree::
:maxdepth: 2
:caption: Tutorial and Examples:

getting_started_with_torchdata_nodes.rst
migrate_to_nodes_from_utils.rst
stateful_dataloader_tutorial.rst


184 changes: 184 additions & 0 deletions docs/source/migrate_to_nodes_from_utils.rst
@@ -0,0 +1,184 @@
.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help people familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`,
to get started with ``torchdata.nodes``, and to provide a starting point for defining
your own dataloading pipelines.

We'll demonstrate how to achieve the most common ``DataLoader`` features, re-use existing samplers and datasets,
and load/save dataloader state. ``torchdata.nodes`` performs at least as well as ``DataLoader`` and
``StatefulDataLoader``; see :ref:`how-does-nodes-perform`.

Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor args and go from there

.. code:: python

   class DataLoader:
       def __init__(
           self,
           dataset: Dataset[_T_co],
           batch_size: Optional[int] = 1,
           shuffle: Optional[bool] = None,
           sampler: Union[Sampler, Iterable, None] = None,
           batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
           num_workers: int = 0,
           collate_fn: Optional[_collate_fn_t] = None,
           pin_memory: bool = False,
           drop_last: bool = False,
           timeout: float = 0,
           worker_init_fn: Optional[_worker_init_fn_t] = None,
           multiprocessing_context=None,
           generator=None,
           *,
           prefetch_factor: Optional[int] = None,
           persistent_workers: bool = False,
           pin_memory_device: str = "",
           in_order: bool = True,
       ):
           ...

As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creates batches of ``batch_size`` indices.
If no sampler is provided, a ``RandomSampler`` or ``SequentialSampler`` is created by default.
The indices are passed to ``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch
of samples. If ``num_workers > 0``, ``DataLoader`` uses multiprocessing to create
subprocesses, and passes the batches of indices to the worker processes, which then call ``Dataset.__getitem__()`` and apply ``collate_fn``
before returning the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch.
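
To make those stages concrete, here is a rough, single-process sketch of that flow. This is purely
illustrative (no worker processes, no ``pin_memory``), and the helper name ``naive_dataloader`` is made
up for this guide:

.. code:: python

   from torch.utils.data import RandomSampler, SequentialSampler, default_collate

   def naive_dataloader(dataset, batch_size, shuffle=False, collate_fn=default_collate, drop_last=False):
       # 1) A sampler yields indices
       sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
       batch = []
       for idx in sampler:
           batch.append(idx)
           # 2) Indices are grouped into batches of batch_size
           if len(batch) == batch_size:
               # 3) Each index goes through Dataset.__getitem__, then collate_fn combines the samples
               yield collate_fn([dataset[i] for i in batch])
               batch = []
       # 4) A final partial batch is emitted unless drop_last is set
       if batch and not drop_last:
           yield collate_fn([dataset[i] for i in batch])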

Now let's look at what an equivalent implementation for ``DataLoader`` might look like, built with ``torchdata.nodes``.

.. code:: python

   from typing import Callable, List, Optional

   import torchdata.nodes as tn
   from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset


   class MapAndCollate:
       """A simple transform that takes a batch of indices, maps with dataset, and then applies
       collate.
       TODO: make this a standard utility in torchdata.nodes
       """
       def __init__(self, dataset, collate_fn):
           self.dataset = dataset
           self.collate_fn = collate_fn

       def __call__(self, batch_of_indices: List[int]):
           batch = [self.dataset[i] for i in batch_of_indices]
           return self.collate_fn(batch)


   # To keep things simple, let's assume that the following args are provided by the caller
   def NodesDataLoader(
       dataset: Dataset,
       batch_size: int,
       shuffle: bool,
       num_workers: int,
       collate_fn: Optional[Callable],
       pin_memory: bool,
       drop_last: bool,
   ):
       # Assume we're working with a map-style dataset
       assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")

       # Start with a sampler, since the caller did not provide one
       sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
       # SamplerWrapper converts a Sampler into a BaseNode
       node = tn.SamplerWrapper(sampler)

       # Now let's batch sampler indices together
       node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

       # Create a map function that accepts a list of indices, applies __getitem__ to them, and
       # then collates the results
       map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

       # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
       # choose process or thread workers. Note that if you're not using free-threaded
       # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
       # and slow down training.
       node = tn.ParallelMapper(
           node,
           map_fn=map_and_collate,
           num_workers=num_workers,
           method="process",  # Set this to "thread" for multi-threading
           in_order=True,
       )

       # Optionally apply pin-memory, and we usually do some pre-fetching
       if pin_memory:
           node = tn.PinMemory(node)
       node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

       # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
       # on it to start a new epoch.
       # Instead, we wrap the node in a Loader, which is an Iterable and handles reset. It
       # also provides state_dict and load_state_dict methods.
       return tn.Loader(node)

Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

   class SquaredDataset(Dataset):
       def __init__(self, len: int):
           self.len = len

       def __len__(self):
           return self.len

       def __getitem__(self, i: int) -> int:
           return i**2


   loader = NodesDataLoader(
       dataset=SquaredDataset(14),
       batch_size=3,
       shuffle=False,
       num_workers=2,
       collate_fn=None,
       pin_memory=False,
       drop_last=False,
   )

   batches = []
   for idx, batch in enumerate(loader):
       if idx == 2:
           state_dict = loader.state_dict()
           # Saves the state_dict after batch 2 has been returned
       batches.append(batch)

   loader.load_state_dict(state_dict)
   batches_after_loading = list(loader)

   print(batches[3:])
   # [tensor([ 81, 100, 121]), tensor([144, 169])]
   print(batches_after_loading)
   # [tensor([ 81, 100, 121]), tensor([144, 169])]

Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.

.. code:: python

   import torch.utils.data

   loaderv1 = torch.utils.data.DataLoader(
       dataset=SquaredDataset(14),
       batch_size=3,
       shuffle=False,
       num_workers=2,
       collate_fn=None,
       pin_memory=False,
       drop_last=False,
       persistent_workers=False,  # Coming soon to torchdata.nodes!
   )

   print(list(loaderv1))
   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
   print(batches)
   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]

IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your ``IterableDataset`` into a ``tn.IterableWrapper``, some functions like
``get_worker_info`` are not yet supported. However, we believe that sharding work between
multi-process workers is often not actually necessary: you can keep some sort of indexing in the main process while
only parallelizing some of the heavier transforms, similar to how the map-style datasets work above. A minimal
sketch of this pattern is shown below.
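
The snippet below is an illustrative sketch only; the ``LinesDataset`` class and the ``str.split`` transform
are made-up stand-ins for your own dataset and heavier transform.

.. code:: python

   import torchdata.nodes as tn
   from torch.utils.data import IterableDataset


   class LinesDataset(IterableDataset):
       # Illustrative single-process IterableDataset that yields raw strings
       def __init__(self, lines):
           self.lines = lines

       def __iter__(self):
           yield from self.lines


   node = tn.IterableWrapper(LinesDataset(["hello world", "torchdata nodes"]))
   # Keep iteration in the main process; parallelize only the heavier transform
   node = tn.ParallelMapper(node, map_fn=lambda s: s.split(), num_workers=2, method="thread")
   loader = tn.Loader(node)
   print(list(loader))
   # [['hello', 'world'], ['torchdata', 'nodes']]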