Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes in tensor generator #327

Merged
merged 3 commits into from
Jun 18, 2020
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions ml4cvd/tensor_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def __init__(
self._started = False
self.workers = []
self.worker_instances = []
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
self.batch_size, self.input_maps, self.output_maps, self.num_workers, self.cache_size, self.weights, self.name, self.keep_paths = \
batch_size, input_maps, output_maps, num_workers, cache_size, weights, name, keep_paths
self.true_epochs = 0
self.stats_string = ""
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
if weights is None:
worker_paths = np.array_split(paths, num_workers)
self.true_epoch_lens = list(map(len, worker_paths))
Expand Down Expand Up @@ -148,7 +148,7 @@ def _init_workers(self):
)
process.start()
self.workers.append(process)
logging.info(f"Started {i} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")
logging.info(f"Started {i + 1} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")

def set_worker_paths(self, paths: List[Path]):
"""In the single worker case, set the worker's paths."""
Expand All @@ -161,11 +161,11 @@ def set_worker_paths(self, paths: List[Path]):
def __next__(self) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[List[str]]]:
if not self._started:
self._init_workers()
if self.stats_q.qsize() == self.num_workers:
self.aggregate_and_print_stats()
if self.run_on_main_thread:
return next(self.worker_instances[0])
else:
if self.stats_q.qsize() == self.num_workers:
self.aggregate_and_print_stats()
return self.q.get(TENSOR_GENERATOR_TIMEOUT)

def aggregate_and_print_stats(self):
Expand Down Expand Up @@ -227,7 +227,7 @@ def aggregate_and_print_stats(self):
f"{stats['Tensors presented']:0.0f} tensors were presented.",
f"{stats['skipped_paths']} paths were skipped because they previously failed.",
f"{error_info}",
f"{self.stats_string}"
f"{self.stats_string}",
])
logging.info(f"\n!!!!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~<!!!!\nAggregated information string:\n\t{info_string}")

Expand Down Expand Up @@ -725,6 +725,8 @@ def test_train_valid_tensor_generators(
tensors: str,
batch_size: int,
num_workers: int,
training_steps: int,
validation_steps: int,
cache_size: float,
balance_csvs: List[str],
keep_paths: bool = False,
Expand Down Expand Up @@ -786,8 +788,11 @@ def test_train_valid_tensor_generators(
test_csv=test_csv,
)
weights = None
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, num_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, num_workers // 2, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)

train_workers = int(training_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
StevenSong marked this conversation as resolved.
Show resolved Hide resolved
valid_workers = int(validation_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, train_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, valid_workers, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)
generate_test = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, test_paths, num_workers, 0, weights, keep_paths or keep_paths_test, name='test_worker', siamese=siamese, augment=False)
return generate_train, generate_valid, generate_test

Expand Down