Update torchdata.nodes docs, use sphinx for API #1396
Merged

Commits
e3943d9  update docstrings for sphinx
e97e00f  update docstrings for sphinx
473ebc0  add migration from torch.utils.data
644c6d4  add performance section
f389db0  add xref to performance in migrate guide
054e6a6  minor pr comments

Getting Started With ``torchdata.nodes`` (beta)
===============================================

Install torchdata with pip.

.. code:: bash

   pip install "torchdata>=0.10.0"

Generator Example
~~~~~~~~~~~~~~~~~

Wrap a generator (or any iterable) to convert it to a ``BaseNode`` and get started:

.. code:: python

   from torchdata.nodes import IterableWrapper, ParallelMapper, Loader

   node = IterableWrapper(range(10))
   node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread")
   loader = Loader(node)
   result = list(loader)
   print(result)
   # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

Sampler Example
~~~~~~~~~~~~~~~

Samplers are still supported, and you can use your existing
``torch.utils.data.Dataset``\ s. See :ref:`migrate-to-nodes-from-utils` for an in-depth
example.

.. code:: python

   import torch.utils.data
   from torch.utils.data import RandomSampler
   from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader


   class SquaredDataset(torch.utils.data.Dataset):
       def __getitem__(self, i: int) -> int:
           return i**2

       def __len__(self):
           return 10


   dataset = SquaredDataset()
   sampler = RandomSampler(dataset)

   # For fine-grained control of iteration order, define your own sampler
   node = SamplerWrapper(sampler)
   # Simply apply the dataset's __getitem__ as a map function to the indices generated by the sampler
   node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread")
   # Loader converts a node (an iterator) into an Iterable that can be reused for multiple epochs
   loader = Loader(node)
   print(list(loader))
   # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1]
   print(list(loader))
   # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36]
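
``Loader`` also exposes ``state_dict`` and ``load_state_dict`` for checkpointing mid-epoch
(demonstrated in depth in the migration guide below). Here is a minimal sketch continuing the
sampler example above; the item counts are illustrative:

.. code:: python

   # Capture the loader's state partway through an epoch, then restore it
   # to resume from the same point.
   state = None
   seen = []
   for idx, x in enumerate(loader):
       if idx == 4:
           state = loader.state_dict()  # progress after the 5th item was returned
       seen.append(x)

   loader.load_state_dict(state)
   print(list(loader))  # yields the same items as seen[5:]
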

.. _migrate-to-nodes-from-utils:

Migrating to ``torchdata.nodes`` from ``torch.utils.data``
==========================================================

This guide is intended to help people familiar with ``torch.utils.data``, or
:class:`~torchdata.stateful_dataloader.StatefulDataLoader`,
get started with ``torchdata.nodes``, and to provide a starting ground for defining
your own dataloading pipelines.

We'll demonstrate how to achieve the most common ``DataLoader`` features, re-use existing
samplers and datasets, and load/save dataloader state. ``torchdata.nodes`` performs at least
as well as ``DataLoader`` and ``StatefulDataLoader``; see :ref:`how-does-nodes-perform`.

Map-Style Datasets
~~~~~~~~~~~~~~~~~~

Let's look at the ``DataLoader`` constructor args and go from there.

.. code:: python

   class DataLoader:
       def __init__(
           self,
           dataset: Dataset[_T_co],
           batch_size: Optional[int] = 1,
           shuffle: Optional[bool] = None,
           sampler: Union[Sampler, Iterable, None] = None,
           batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
           num_workers: int = 0,
           collate_fn: Optional[_collate_fn_t] = None,
           pin_memory: bool = False,
           drop_last: bool = False,
           timeout: float = 0,
           worker_init_fn: Optional[_worker_init_fn_t] = None,
           multiprocessing_context=None,
           generator=None,
           *,
           prefetch_factor: Optional[int] = None,
           persistent_workers: bool = False,
           pin_memory_device: str = "",
           in_order: bool = True,
       ):
           ...

As a refresher, here is roughly how dataloading works in ``torch.utils.data.DataLoader``:
``DataLoader`` begins by generating indices from a ``sampler`` and creating batches of
``batch_size`` indices. If no sampler is provided, a ``RandomSampler`` or ``SequentialSampler``
is created by default. The indices are passed to ``Dataset.__getitem__()``, and then a
``collate_fn`` is applied to the batch of samples. If ``num_workers > 0``, ``DataLoader``
uses multi-processing to create subprocesses, and passes the batches of indices to the worker
processes, which then call ``Dataset.__getitem__()`` and apply ``collate_fn`` before returning
the batches to the main process. At that point, ``pin_memory`` may be applied to the tensors
in the batch.
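
In single-process terms (ignoring workers and pinning), that flow boils down to something like
the following sketch; ``dataset``, ``sampler``, and ``collate_fn`` stand in for whatever the
caller provides:

.. code:: python

   # Rough single-process sketch of the flow described above: draw indices
   # from the sampler, group them into batches, map each batch through
   # Dataset.__getitem__, then collate the samples.
   def naive_loader(dataset, sampler, batch_size, collate_fn, drop_last=False):
       batch_of_indices = []
       for idx in sampler:
           batch_of_indices.append(idx)
           if len(batch_of_indices) == batch_size:
               yield collate_fn([dataset[i] for i in batch_of_indices])
               batch_of_indices = []
       if batch_of_indices and not drop_last:  # final partial batch
           yield collate_fn([dataset[i] for i in batch_of_indices])
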

Now let's look at what an equivalent implementation for ``DataLoader`` might look like, built
with ``torchdata.nodes``.

.. code:: python

   from typing import List, Callable
   import torchdata.nodes as tn
   from torch.utils.data import RandomSampler, SequentialSampler, default_collate, Dataset

   class MapAndCollate:
       """A simple transform that takes a batch of indices, maps with dataset, and then applies
       collate.
       TODO: make this a standard utility in torchdata.nodes
       """
       def __init__(self, dataset, collate_fn):
           self.dataset = dataset
           self.collate_fn = collate_fn

       def __call__(self, batch_of_indices: List[int]):
           batch = [self.dataset[i] for i in batch_of_indices]
           return self.collate_fn(batch)

   # To keep things simple, let's assume that the following args are provided by the caller
   def NodesDataLoader(
       dataset: Dataset,
       batch_size: int,
       shuffle: bool,
       num_workers: int,
       collate_fn: Callable | None,
       pin_memory: bool,
       drop_last: bool,
   ):
       # Assume we're working with a map-style dataset
       assert hasattr(dataset, "__getitem__") and hasattr(dataset, "__len__")
       # Start with a sampler, since the caller did not provide one
       sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
       # SamplerWrapper converts a Sampler to a BaseNode
       node = tn.SamplerWrapper(sampler)

       # Now let's batch sampler indices together
       node = tn.Batcher(node, batch_size=batch_size, drop_last=drop_last)

       # Create a map function that accepts a list of indices, applies __getitem__ to them,
       # and then collates the samples
       map_and_collate = MapAndCollate(dataset, collate_fn or default_collate)

       # MapAndCollate is doing most of the heavy lifting, so let's parallelize it. We could
       # choose process or thread workers. Note that if you're not using Free-Threaded
       # Python (e.g. 3.13t) with -Xgil=0, then multi-threading might result in GIL contention,
       # and slow down training.
       node = tn.ParallelMapper(
           node,
           map_fn=map_and_collate,
           num_workers=num_workers,
           method="process",  # Set this to "thread" for multi-threading
           in_order=True,
       )

       # Optionally apply pin-memory, and we usually do some pre-fetching
       if pin_memory:
           node = tn.PinMemory(node)
       node = tn.Prefetcher(node, prefetch_factor=num_workers * 2)

       # Note that node is an iterator, and once it's exhausted, you'll need to call .reset()
       # on it to start a new epoch.
       # Instead, we wrap the node in a Loader, which is an iterable and handles reset. It
       # also provides state_dict and load_state_dict methods.
       return tn.Loader(node)

Now let's test this out with a trivial dataset, and demonstrate how state management works.

.. code:: python

   class SquaredDataset(Dataset):
       def __init__(self, len: int):
           self.len = len

       def __len__(self):
           return self.len

       def __getitem__(self, i: int) -> int:
           return i**2

   loader = NodesDataLoader(
       dataset=SquaredDataset(14),
       batch_size=3,
       shuffle=False,
       num_workers=2,
       collate_fn=None,
       pin_memory=False,
       drop_last=False,
   )

   batches = []
   for idx, batch in enumerate(loader):
       if idx == 2:
           state_dict = loader.state_dict()
           # Saves the state_dict after batch 2 has been returned
       batches.append(batch)

   loader.load_state_dict(state_dict)
   batches_after_loading = list(loader)
   print(batches[3:])
   # [tensor([ 81, 100, 121]), tensor([144, 169])]
   print(batches_after_loading)
   # [tensor([ 81, 100, 121]), tensor([144, 169])]

Let's also compare this to ``torch.utils.data.DataLoader``, as a sanity check.

.. code:: python

   import torch.utils.data

   loaderv1 = torch.utils.data.DataLoader(
       dataset=SquaredDataset(14),
       batch_size=3,
       shuffle=False,
       num_workers=2,
       collate_fn=None,
       pin_memory=False,
       drop_last=False,
       persistent_workers=False,  # Coming soon to torchdata.nodes!
   )
   print(list(loaderv1))
   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]
   print(batches)
   # [tensor([0, 1, 4]), tensor([ 9, 16, 25]), tensor([36, 49, 64]), tensor([ 81, 100, 121]), tensor([144, 169])]

IterableDatasets
~~~~~~~~~~~~~~~~

Coming soon! While you can already plug your ``IterableDataset`` into a ``tn.IterableWrapper``,
some functions like ``get_worker_info`` are not yet supported. However, we believe that sharding
work between multi-process workers is often not actually necessary: you can keep some sort of
indexing in the main process while parallelizing only some of the heavier transforms, similar to
how the map-style datasets above work.
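
For example, here is a minimal sketch of that pattern (the ``FileListDataset`` and ``decode``
names are hypothetical stand-ins for a lightweight record stream and a heavy transform):

.. code:: python

   import torchdata.nodes as tn
   from torch.utils.data import IterableDataset

   class FileListDataset(IterableDataset):
       # Hypothetical: yields cheap records (e.g. paths) from the main process
       def __iter__(self):
           for i in range(100):
               yield {"path": f"shard_{i:03d}.bin"}

   def decode(record):
       # Hypothetical heavy transform, e.g. read and preprocess the file
       return record["path"].upper()

   # Keep the cheap record stream in the main process...
   node = tn.IterableWrapper(FileListDataset())
   # ...and parallelize only the expensive transform across workers.
   node = tn.ParallelMapper(node, map_fn=decode, num_workers=4, method="thread")
   loader = tn.Loader(node)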

@divyanshk added xref here