Skip to content

Commit

Permalink
add test to enforce infinite buffer size for all applicable datapipes (
Browse files Browse the repository at this point in the history
…#5707)

* add test to enforce infinite buffer size for all applicable datapipes

* use utility function to extract datapipes

* check for buffer_size attr rather than type

* simplify
  • Loading branch information
pmeier authored Mar 30, 2022
1 parent aa21197 commit 93104c1
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions test/test_prototype_builtin_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,23 @@
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter
from torchvision._utils import sequence_to_str
from torchvision.prototype import transforms, datasets
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
from torchvision.prototype.features import Image, Label

# Sample comparison with zero relative/absolute tolerance; NaNs compare equal to each other.
assert_samples_equal = functools.partial(
    assert_equal,
    pair_types=(TensorLikePair, ObjectPair),
    rtol=0,
    atol=0,
    equal_nan=True,
)


def extract_datapipes(dp):
    """Collect the datapipes of the graph rooted at ``dp``.

    The graph is built with ``traverse(dp, only_datapipe=True)`` and then passed to
    ``get_all_graph_pipes``, whose result is returned unchanged.
    """
    graph = traverse(dp, only_datapipe=True)
    return get_all_graph_pipes(graph)


@pytest.fixture
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
Expand Down Expand Up @@ -125,16 +130,12 @@ def test_serializable(self, test_home, dataset_mock, config):
# Parametrized over every dataset mock and over the two datapipe types (Shuffler, ShardingFilter).
@parametrize_dataset_mocks(DATASET_MOCKS)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
# Loads the mocked dataset and raises AssertionError when no datapipe of ``annotation_dp_type`` is found in it.
# NOTE(review): this span looks like a rendered diff without +/- markers. The nested scan() helper and the
# ``type(dp) is ...`` check appear to be the removed lines, and the isinstance()/extract_datapipes() check
# the added one -- confirm against the commit before treating this as runnable code.
def scan(graph):
# Recursively yields each node of the graph mapping, followed by the nodes of its sub-graph.
for node, sub_graph in graph.items():
yield node
yield from scan(sub_graph)

dataset_mock.prepare(test_home, config)

dataset = datasets.load(dataset_mock.name, **config)

# Exact-type match over the hand-rolled traversal (presumably the pre-commit form; see note above).
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
# isinstance() match over extract_datapipes(), so subclasses of the annotation type are accepted as well.
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")

@parametrize_dataset_mocks(DATASET_MOCKS)
Expand All @@ -148,6 +149,17 @@ def test_save_load(self, test_home, dataset_mock, config):
buffer.seek(0)
assert_samples_equal(torch.load(buffer), sample)

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_infinite_buffer_size(self, test_home, dataset_mock, config):
    """Check that every datapipe exposing ``buffer_size`` uses ``INFINITE_BUFFER_SIZE``.

    Datapipes are selected by the presence of a ``buffer_size`` attribute, not by their type.
    """
    dataset_mock.prepare(test_home, config)
    dataset = datasets.load(dataset_mock.name, **config)

    buffered_datapipes = (
        datapipe for datapipe in extract_datapipes(dataset) if hasattr(datapipe, "buffer_size")
    )
    for datapipe in buffered_datapipes:
        # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is
        #  resolved
        assert datapipe.buffer_size == INFINITE_BUFFER_SIZE


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
Expand Down

0 comments on commit 93104c1

Please sign in to comment.