From 6f9886ec34d6a825fe334487eb4fba5f4fb91ccb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 30 Mar 2022 15:34:19 +0200 Subject: [PATCH 1/4] add test to enforce infinite buffer size for all applicable datapipes --- test/test_prototype_builtin_datasets.py | 51 +++++++++++++++++++++---- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index b4eed473cdc..53a307eae7d 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -7,11 +7,21 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair -from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse -from torchdata.datapipes.iter import IterDataPipe, Shuffler +from torchdata.datapipes.iter import ( + IterDataPipe, + Shuffler, + ShardingFilter, + Demultiplexer, + Forker, + Grouper, + MaxTokenBucketizer, + UnZipper, + IterKeyZipper, +) from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets +from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE from torchvision.prototype.features import Image, Label assert_samples_equal = functools.partial( @@ -35,6 +45,15 @@ def test_coverage(): ) +def extract_datapipes(dp): + def scan(graph): + for node, sub_graph in graph.items(): + yield node + yield from scan(sub_graph) + + yield from scan(traverse(dp)) + + @pytest.mark.filterwarnings("error") class TestCommon: @parametrize_dataset_mocks(DATASET_MOCKS) @@ -125,16 +144,12 @@ def test_serializable(self, test_home, dataset_mock, config): @parametrize_dataset_mocks(DATASET_MOCKS) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): - def scan(graph): - for node, sub_graph in graph.items(): - yield node - yield from scan(sub_graph) dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) - if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): + if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @parametrize_dataset_mocks(DATASET_MOCKS) @@ -148,6 +163,28 @@ def test_save_load(self, test_home, dataset_mock, config): buffer.seek(0) assert_samples_equal(torch.load(buffer), sample) + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_infinite_buffer_size(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + for dp in extract_datapipes(dataset): + if isinstance( + dp, + ( + Shuffler, + Demultiplexer, + Forker, + Grouper, + MaxTokenBucketizer, + UnZipper, + IterKeyZipper, + ), + ): + # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is + # resolved + assert dp.buffer_size == INFINITE_BUFFER_SIZE + @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: From eca430bba21929855347513bd6b1470c57195c2e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 30 Mar 2022 16:05:25 +0200 Subject: [PATCH 2/4] use utility function to extract datapipes --- test/test_prototype_builtin_datasets.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 53a307eae7d..5c9b971268c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -8,6 +8,7 @@ from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair from torch.utils.data.graph import traverse +from torch.utils.data.graph_settings import get_all_graph_pipes from torchdata.datapipes.iter import ( IterDataPipe, Shuffler, @@ -29,6 +30,10 @@ ) +def extract_datapipes(dp): + return get_all_graph_pipes(traverse(dp, only_datapipe=True)) + + @pytest.fixture def test_home(mocker, tmp_path): mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path)) @@ -45,15 +50,6 @@ def test_coverage(): ) -def extract_datapipes(dp): - def scan(graph): - for node, sub_graph in graph.items(): - yield node - yield from scan(sub_graph) - - yield from scan(traverse(dp)) - - @pytest.mark.filterwarnings("error") class TestCommon: @parametrize_dataset_mocks(DATASET_MOCKS) From 68b13a956503a04e550b270f5ed94bffb3a89687 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 30 Mar 2022 16:07:40 +0200 Subject: [PATCH 3/4] check for buffer_size attr rather than type --- test/test_prototype_builtin_datasets.py | 35 +++++++------------------ 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 5c9b971268c..56d5165e32c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -9,17 +9,7 @@ from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes -from torchdata.datapipes.iter import ( - IterDataPipe, - Shuffler, - ShardingFilter, - Demultiplexer, - Forker, - Grouper, - MaxTokenBucketizer, - UnZipper, - IterKeyZipper, -) +from torchdata.datapipes.iter import IterDataPipe, Shuffler, ShardingFilter from torchvision._utils import sequence_to_str from torchvision.prototype import transforms, datasets from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE @@ -165,21 +155,14 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) for dp in extract_datapipes(dataset): - if isinstance( - dp, - ( - Shuffler, - Demultiplexer, - Forker, - Grouper, - MaxTokenBucketizer, - UnZipper, - IterKeyZipper, - ), - ): - # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is - # resolved - assert dp.buffer_size == INFINITE_BUFFER_SIZE + try: + buffer_size = getattr(dp, "buffer_size") + except AttributeError: + continue + + # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is + # resolved + assert buffer_size == INFINITE_BUFFER_SIZE @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) From 7e36fd97581056058b4827a10ef824036c3b5322 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 30 Mar 2022 16:16:14 +0200 Subject: [PATCH 4/4] simplify --- test/test_prototype_builtin_datasets.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 56d5165e32c..8d51125f41c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -155,14 +155,10 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) for dp in extract_datapipes(dataset): - try: - buffer_size = getattr(dp, "buffer_size") - except AttributeError: - continue - - # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is - # resolved - assert buffer_size == INFINITE_BUFFER_SIZE + if hasattr(dp, "buffer_size"): + # TODO: replace this with the proper sentinel as soon as https://github.com/pytorch/data/issues/335 is + # resolved + assert dp.buffer_size == INFINITE_BUFFER_SIZE @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])