2024-12-12 nightly release (c0edd90)
pytorchbot committed Dec 12, 2024
1 parent 19f24de · commit 5c4c403 · 16 changed files with 503 additions and 258 deletions.
Getting Started With ``torchdata.nodes`` (beta)
===============================================

Install torchdata with pip. (The requirement specifier is quoted so the shell does not
interpret ``>`` as a redirect.)

.. code:: bash

    pip install "torchdata>=0.10.0"
Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a ``BaseNode`` and get started:

.. code:: python

    from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

    node = IterableWrapper(range(10))
    node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
    loader = Loader(node)
    result = list(loader)
    print(result)
    # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s. See :ref:`migrate-to-nodes-from-utils` for an in-depth
example.

.. code:: python

    import torch.utils.data
    from torch.utils.data import RandomSampler
    from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


    class SquaredDataset(torch.utils.data.Dataset):
        def __getitem__(self, i: int) -> int:
            return i**2

        def __len__(self):
            return 10


    dataset = SquaredDataset()
    # For fine-grained control of iteration order, define your own sampler
    sampler = RandomSampler(dataset)
    node = SamplerWrapper(sampler)

    # Apply the dataset's __getitem__ as a map function to the indices generated by the sampler
    node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")

    # Loader converts a node (an iterator) into an Iterable that may be reused across epochs
    loader = Loader(node)
    print(list(loader))
    # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
    print(list(loader))
    # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
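Taking that comment about fine-grained iteration order literally, a custom sampler is
just an iterable of indices. Here is a minimal sketch (``EvensFirstSampler`` is a
hypothetical example, not part of ``torchdata``) that yields even indices before odd ones:

.. code:: python

    from torch.utils.data import Sampler


    class EvensFirstSampler(Sampler):
        """Yields even indices first, then odd ones."""

        def __init__(self, n: int):
            self.n = n

        def __iter__(self):
            yield from range(0, self.n, 2)
            yield from range(1, self.n, 2)

        def __len__(self):
            return self.n


    node = SamplerWrapper(EvensFirstSampler(len(dataset)))
    node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread", in_order=True)
    print(list(Loader(node)))
    # [0, 4, 16, 36, 64, 1, 9, 25, 49, 81]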
.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help people familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`,
get started with ``torchdata.nodes``, and to provide a starting point for defining
your own dataloading pipelines.

We'll demonstrate how to achieve the most common ``DataLoader`` features, re-use existing
samplers and datasets, and load/save dataloader state. ``torchdata.nodes`` performs at
least as well as ``DataLoader`` and ``StatefulDataLoader``; see :ref:`how-does-nodes-perform`.
Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor args and go from there.

.. code:: python

    class DataLoader:
        def __init__(
            self,
            dataset: Dataset[_T_co],
            batch_size: Optional[int] = 1,
            shuffle: Optional[bool] = None,
            sampler: Union[Sampler, Iterable, None] = None,
            batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
            num_workers: int = 0,
            collate_fn: Optional[_collate_fn_t] = None,
            pin_memory: bool = False,
            drop_last: bool = False,
            timeout: float = 0,
            worker_init_fn: Optional[_worker_init_fn_t] = None,
            multiprocessing_context=None,
            generator=None,
            *,
            prefetch_factor: Optional[int] = None,
            persistent_workers: bool = False,
            pin_memory_device: str = "",
            in_order: bool = True,
        ):
            ...
As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creating batches of
``batch_size`` indices. If no sampler is provided, a ``RandomSampler`` or
``SequentialSampler`` is created by default. The indices are passed to
``Dataset.__getitem__()``, and then a ``collate_fn`` is applied to the batch of samples.
If ``num_workers > 0``, the batches of indices are passed to worker subprocesses, which
call ``Dataset.__getitem__()`` and apply ``collate_fn`` before returning the batches to
the main process. At that point, ``pin_memory`` may be applied to the tensors in the batch.
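Ignoring workers, memory pinning, and error handling, that flow boils down to something
like the following sketch (an illustration of the steps above, not how ``DataLoader`` is
actually implemented):

.. code:: python

    def simple_dataloader(dataset, sampler, batch_size, collate_fn, drop_last):
        batch_of_indices = []
        for idx in sampler:  # 1. generate indices from the sampler
            batch_of_indices.append(idx)
            if len(batch_of_indices) == batch_size:  # 2. batch the indices
                samples = [dataset[i] for i in batch_of_indices]  # 3. Dataset.__getitem__
                yield collate_fn(samples)  # 4. collate the batch of samples
                batch_of_indices = []
        if batch_of_indices and not drop_last:  # final partial batch
            yield collate_fn([dataset[i] for i in batch_of_indices])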
Now let's look at what an equivalent implementation of ``DataLoader`` might look like,
built with ``torchdata.nodes``.
.. code:: python

    from typing import List, Callable

    import torchdata.nodes as tn
    from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset


    class MapAndCollate:
        """A simple transform that takes a batch of indices, maps with dataset, and then
        applies collate.
        TODO: make this a standard utility in torchdata.nodes
        """

        def __init__(self, dataset, collate_fn):
            self.dataset = dataset
            self.collate_fn = collate_fn

        def __call__(self, batch_of_indices: List[int]):
            batch = [self.dataset[i] for i in batch_of_indices]
            return self.collate_fn(batch)


    # To keep things simple, let's assume that the following args are provided by the caller
    def NodesDataLoader(
        dataset: Dataset,
        batch_size: int,
        shuffle: bool,
        num_workers: int,
        collate_fn: Callable | None,
        pin_memory: bool,
        drop_last: bool,
    ):
        # Assume we're working with a map-style dataset
        assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")

        # Start with a sampler, since the caller did not provide one
        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        # SamplerWrapper converts a Sampler into a BaseNode
        node = tn.SamplerWrapper(sampler)

        # Now let's batch sampler indices together
        node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

        # Create a map function that accepts a list of indices, applies __getitem__ to
        # each, and then collates them
        map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

        # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We
        # could choose process or thread workers. Note that if you're not using
        # free-threaded Python (e.g. 3.13t with -X gil=0), multi-threading may result in
        # GIL contention and slow down training.
        node = tn.ParallelMapper(
            node,
            map_fn=map_and_collate,
            num_workers=num_workers,
            method="process",  # Set this to "thread" for multi-threading
            in_order=True,
        )

        # Optionally apply pin-memory, and we usually do some pre-fetching
        if pin_memory:
            node = tn.PinMemory(node)
        node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

        # Note that node is an iterator: once it's exhausted, you'll need to call
        # .reset() on it to start a new epoch. Instead, we wrap the node in a Loader,
        # which is an Iterable that handles reset. It also provides state_dict and
        # load_state_dict methods.
        return tn.Loader(node)
Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

    class SquaredDataset(Dataset):
        def __init__(self, len: int):
            self.len = len

        def __len__(self):
            return self.len

        def __getitem__(self, i: int) -> int:
            return i**2


    loader = NodesDataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
    )

    batches = []
    for idx, batch in enumerate(loader):
        if idx == 2:
            state_dict = loader.state_dict()
            # Saves the state_dict after batch 2 has been returned
        batches.append(batch)

    loader.load_state_dict(state_dict)
    batches_after_loading = list(loader)
    print(batches[3:])
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches_after_loading)
    # [tensor([ 81, 100, 121]), tensor([144, 169])]
Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.

.. code:: python

    import torch.utils.data

    loaderv1 = torch.utils.data.DataLoader(
        dataset=SquaredDataset(14),
        batch_size=3,
        shuffle=False,
        num_workers=2,
        collate_fn=None,
        pin_memory=False,
        drop_last=False,
        persistent_workers=False,  # Coming soon to torchdata.nodes!
    )
    print(list(loaderv1))
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
    print(batches)
    # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your ``IterableDataset`` into a
``tn.IterableWrapper``, some functions like ``get_worker_info`` are not yet supported.
However, we believe that sharding work between multi-process workers is often not
actually necessary: you can keep some sort of indexing in the main process while
parallelizing only the heavier transforms, similar to how map-style datasets work above.
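As a minimal sketch of that pattern (``load_and_decode`` and the shard paths below are
hypothetical stand-ins, not part of ``torchdata``), keep cheap "index-like" records in
the main process and parallelize only the expensive transform:

.. code:: python

    import torchdata.nodes as tn

    # Cheap to produce in the main process; no worker sharding needed
    paths = [f"shard_{i:03d}.bin" for i in range(100)]


    def load_and_decode(path: str) -> bytes:
        # Stand-in for an expensive per-record transform (I/O, decoding, augmentation, ...)
        return path.encode("utf-8")


    node = tn.IterableWrapper(paths)  # indexing stays in the main process
    node = tn.ParallelMapper(node, map_fn=load_and_decode, num_workers=4, method="thread")
    loader = tn.Loader(node)
    assert len(list(loader)) == 100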