diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 0116b79314..d88c995867 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -244,34 +244,42 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # Get the lowest rank in the current node local_rank_0 = dist.get_global_rank() - dist.get_local_rank() - for plan_item in plan.items: - relative_file_path = self.storage_data[plan_item.storage_index].relative_path - # Check if the file is scheduled to be downloaded by a lower rank on the same node - # i.e. if rank 0 and rank 1 on the same node have the same the same required file, - # only rank 0 should download it and not rank 1. - is_downloaded = any( - relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank()) - ) + try: + for plan_item in plan.items: + relative_file_path = self.storage_data[plan_item.storage_index].relative_path + # Check if the file is scheduled to be downloaded by a lower rank on the same node + # i.e. if rank 0 and rank 1 on the same node have the same the same required file, + # only rank 0 should download it and not rank 1. + is_downloaded = any( + relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank()) + ) - # Download the shard file to the relative path it's associated to and save that relative path - # to the root directory specified to the FileSystem reader constructor. - file_destination = str(Path(self.destination_path) / Path(relative_file_path)) - - # The file could have already been downloaded as different plan items can point to same file. - if not is_downloaded and not os.path.exists(file_destination): - log.debug(f'Downloading {relative_file_path} to {file_destination}.') - object_name = str(Path(self.source_path) / Path(relative_file_path)) - if isinstance(self.object_store, ObjectStore): - self.object_store.download_object( - object_name=object_name, - filename=file_destination, - ) - else: - self.object_store.download_file( - remote_file_name=object_name, - destination=file_destination, - ) - log.debug(f'Finished downloading {relative_file_path} to {file_destination}.') + # Download the shard file to the relative path it's associated to and save that relative path + # to the root directory specified to the FileSystem reader constructor. + file_destination = str(Path(self.destination_path) / Path(relative_file_path)) + + # The file could have already been downloaded as different plan items can point to same file. + if not is_downloaded and not os.path.exists(file_destination): + log.debug(f'Downloading {relative_file_path} to {file_destination}.') + object_name = str(Path(self.source_path) / Path(relative_file_path)) + if isinstance(self.object_store, ObjectStore): + self.object_store.download_object( + object_name=object_name, + filename=file_destination, + ) + else: + self.object_store.download_file( + remote_file_name=object_name, + destination=file_destination, + ) + log.debug(f'Finished downloading {relative_file_path} to {file_destination}.') + except Exception as e: + # PyTorch will capture any exception of this function, + # and dist.all_gather_objects(exception) before raising it. + # If that all_gather_objects fails, the exception is never visible to user. + # We immediately print the exception to avoid that situation. + log.error(f'Exception {type(e)} raised during downloading: {str(e)}') + raise e # 3. Wait for all ranks to finish. log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')