From b87dacbb9583265e74fa4d8f5e3408514da03fa1 Mon Sep 17 00:00:00 2001 From: Steven Song Date: Thu, 18 Jun 2020 17:07:00 -0400 Subject: [PATCH] fixes in tensor generator (#327) * #260 #324 #326 * get stats q * variable names --- ml4cvd/tensor_generators.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ml4cvd/tensor_generators.py b/ml4cvd/tensor_generators.py index adaca18f6..01ba66dd1 100755 --- a/ml4cvd/tensor_generators.py +++ b/ml4cvd/tensor_generators.py @@ -90,12 +90,12 @@ def __init__( self._started = False self.workers = [] self.worker_instances = [] + if num_workers == 0: + num_workers = 1 # The one worker is the main thread self.batch_size, self.input_maps, self.output_maps, self.num_workers, self.cache_size, self.weights, self.name, self.keep_paths = \ batch_size, input_maps, output_maps, num_workers, cache_size, weights, name, keep_paths self.true_epochs = 0 self.stats_string = "" - if num_workers == 0: - num_workers = 1 # The one worker is the main thread if weights is None: worker_paths = np.array_split(paths, num_workers) self.true_epoch_lens = list(map(len, worker_paths)) @@ -148,7 +148,7 @@ def _init_workers(self): ) process.start() self.workers.append(process) - logging.info(f"Started {i} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.") + logging.info(f"Started {i + 1} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.") def set_worker_paths(self, paths: List[Path]): """In the single worker case, set the worker's paths.""" @@ -161,11 +161,11 @@ def set_worker_paths(self, paths: List[Path]): def __next__(self) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[List[str]]]: if not self._started: self._init_workers() + if self.stats_q.qsize() == self.num_workers: + self.aggregate_and_print_stats() if self.run_on_main_thread: return next(self.worker_instances[0]) else: - if self.stats_q.qsize() == self.num_workers: - self.aggregate_and_print_stats() return self.q.get(TENSOR_GENERATOR_TIMEOUT) def aggregate_and_print_stats(self): @@ -227,7 +227,7 @@ def aggregate_and_print_stats(self): f"{stats['Tensors presented']:0.0f} tensors were presented.", f"{stats['skipped_paths']} paths were skipped because they previously failed.", f"{error_info}", - f"{self.stats_string}" + f"{self.stats_string}", ]) logging.info(f"\n!!!!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~