Skip to content

Commit

Permalink
test for getting true epoch #355
Browse files Browse the repository at this point in the history
  • Loading branch information
StevenSong committed Jul 6, 2020
1 parent 9736e29 commit 496647e
Showing 1 changed file with 68 additions and 6 deletions.
74 changes: 68 additions & 6 deletions tests/test_tensor_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import defaultdict

from ml4cvd.defines import TENSOR_EXT
from ml4cvd.tensor_generators import TensorGenerator
from ml4cvd.tensor_generators import _sample_csv_to_set, get_train_valid_test_paths, get_train_valid_test_paths_split_by_csvs


Expand Down Expand Up @@ -92,6 +93,36 @@ def test_set(request, train_valid_test_csv):
return train_valid_test_csv[2]


@pytest.fixture(scope='function')
def train_valid_test_paths(default_arguments, train_valid_test_csv):
args = default_arguments
(train_csv, train_ids), (valid_csv, valid_ids), (test_csv, test_ids) = train_valid_test_csv
return get_train_valid_test_paths(
tensors=args.tensors,
valid_ratio=args.valid_ratio,
test_ratio=args.test_ratio,
sample_csv=None,
train_csv=train_csv,
valid_csv=valid_csv,
test_csv=test_csv,
)


@pytest.fixture(scope='function')
def train_paths(train_valid_test_paths):
return train_valid_test_paths[0]


@pytest.fixture(scope='function')
def valid_paths(train_valid_test_paths):
return train_valid_test_paths[1]


@pytest.fixture(scope='function')
def test_paths(train_valid_test_paths):
return train_valid_test_paths[2]


@pytest.fixture(scope='function')
def valid_test_ratio():
valid_ratio = np.random.randint(1, 5) / 10
Expand All @@ -113,6 +144,43 @@ def test_ratio(request, valid_test_ratio):
return valid_test_ratio[1]


class TestTensorGenerator:

# this test should currently fail - the failure is flaky
def test_get_true_epoch(self, default_arguments, train_paths):
num_workers = 2
num_tensors = len(train_paths) # 8 paths by default
batch_size = 2
num_steps = 20 # each path should be visited exactly 5 times

generator = TensorGenerator(
paths=train_paths,
keep_paths=True,
batch_size=batch_size,
input_maps=default_arguments.tensor_maps_in,
output_maps=default_arguments.tensor_maps_out,
num_workers=num_workers,
cache_size=default_arguments.cache_size,
)

rets = []
for i in range(num_steps):
rets.append(next(generator))

paths = []
for ret in rets:
paths.extend(ret[3])
unique_paths, counts = np.unique(paths, return_counts=True)
unique_counts = np.unique(counts)

try:
assert len(unique_counts) == 1 # make sure the tensors visited were seen the same number of times
assert set(unique_paths) == set(train_paths) # make sure all the tensors are visited
assert unique_counts[0] == batch_size * num_steps / num_tensors # make sure the tensors are visited the expected number of times
finally:
del generator


class TestSampleCsvToSet:
def test_sample_csv(self, sample_csv):
csv_path, sample_ids = sample_csv
Expand Down Expand Up @@ -230,9 +298,3 @@ def test_get_paths_overlap(self, default_arguments, train_valid_test_csv):
)

# TODO test method with balance csvs


class TestTensorGenerator:

def test_get_true_epoch(self):
pass

0 comments on commit 496647e

Please sign in to comment.