diff --git a/tests/test_tensor_generators.py b/tests/test_tensor_generators.py index 8128b3d51..7687cd5c4 100644 --- a/tests/test_tensor_generators.py +++ b/tests/test_tensor_generators.py @@ -5,6 +5,7 @@ from collections import defaultdict from ml4cvd.defines import TENSOR_EXT +from ml4cvd.tensor_generators import TensorGenerator from ml4cvd.tensor_generators import _sample_csv_to_set, get_train_valid_test_paths, get_train_valid_test_paths_split_by_csvs @@ -92,6 +93,36 @@ def test_set(request, train_valid_test_csv): return train_valid_test_csv[2] +@pytest.fixture(scope='function') +def train_valid_test_paths(default_arguments, train_valid_test_csv): + args = default_arguments + (train_csv, train_ids), (valid_csv, valid_ids), (test_csv, test_ids) = train_valid_test_csv + return get_train_valid_test_paths( + tensors=args.tensors, + valid_ratio=args.valid_ratio, + test_ratio=args.test_ratio, + sample_csv=None, + train_csv=train_csv, + valid_csv=valid_csv, + test_csv=test_csv, + ) + + +@pytest.fixture(scope='function') +def train_paths(train_valid_test_paths): + return train_valid_test_paths[0] + + +@pytest.fixture(scope='function') +def valid_paths(train_valid_test_paths): + return train_valid_test_paths[1] + + +@pytest.fixture(scope='function') +def test_paths(train_valid_test_paths): + return train_valid_test_paths[2] + + @pytest.fixture(scope='function') def valid_test_ratio(): valid_ratio = np.random.randint(1, 5) / 10 @@ -113,6 +144,43 @@ def test_ratio(request, valid_test_ratio): return valid_test_ratio[1] +class TestTensorGenerator: + + # this test should currently fail - the failure is flaky + def test_get_true_epoch(self, default_arguments, train_paths): + num_workers = 2 + num_tensors = len(train_paths) # 8 paths by default + batch_size = 2 + num_steps = 20 # each path should be visited exactly 5 times + + generator = TensorGenerator( + paths=train_paths, + keep_paths=True, + batch_size=batch_size, + input_maps=default_arguments.tensor_maps_in, + output_maps=default_arguments.tensor_maps_out, + num_workers=num_workers, + cache_size=default_arguments.cache_size, + ) + + rets = [] + for i in range(num_steps): + rets.append(next(generator)) + + paths = [] + for ret in rets: + paths.extend(ret[3]) + unique_paths, counts = np.unique(paths, return_counts=True) + unique_counts = np.unique(counts) + + try: + assert len(unique_counts) == 1 # make sure the tensors visited were seen the same number of times + assert set(unique_paths) == set(train_paths) # make sure all the tensors are visited + assert unique_counts[0] == batch_size * num_steps / num_tensors # make sure the tensors are visited the expected number of times + finally: + del generator + + class TestSampleCsvToSet: def test_sample_csv(self, sample_csv): csv_path, sample_ids = sample_csv @@ -230,9 +298,3 @@ def test_get_paths_overlap(self, default_arguments, train_valid_test_csv): ) # TODO test method with balance csvs - - -class TestTensorGenerator: - - def test_get_true_epoch(self): - pass