Skip to content

Commit

Permalink
Explicitly print checkpoint downloading exception (mosaicml#3131)
Browse files Browse the repository at this point in the history
* a

* up

* up

* up

* up

* Update composer/utils/checkpoint.py

Co-authored-by: Mihir Patel <[email protected]>

---------

Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
bigning and mvpatel2000 authored Mar 21, 2024
1 parent fd3202d commit cf031e2
Showing 1 changed file with 35 additions and 27 deletions.
62 changes: 35 additions & 27 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,34 +244,42 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
# Get the lowest rank in the current node
local_rank_0 = dist.get_global_rank() - dist.get_local_rank()

for plan_item in plan.items:
relative_file_path = self.storage_data[plan_item.storage_index].relative_path
# Check if the file is scheduled to be downloaded by a lower rank on the same node
# i.e. if rank 0 and rank 1 on the same node have the same the same required file,
# only rank 0 should download it and not rank 1.
is_downloaded = any(
relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank())
)
try:
for plan_item in plan.items:
relative_file_path = self.storage_data[plan_item.storage_index].relative_path
# Check if the file is scheduled to be downloaded by a lower rank on the same node
# i.e. if rank 0 and rank 1 on the same node have the same the same required file,
# only rank 0 should download it and not rank 1.
is_downloaded = any(
relative_file_path in all_file_paths[i] for i in range(local_rank_0, dist.get_global_rank())
)

# Download the shard file to the relative path it's associated to and save that relative path
# to the root directory specified to the FileSystem reader constructor.
file_destination = str(Path(self.destination_path) / Path(relative_file_path))

# The file could have already been downloaded as different plan items can point to same file.
if not is_downloaded and not os.path.exists(file_destination):
log.debug(f'Downloading {relative_file_path} to {file_destination}.')
object_name = str(Path(self.source_path) / Path(relative_file_path))
if isinstance(self.object_store, ObjectStore):
self.object_store.download_object(
object_name=object_name,
filename=file_destination,
)
else:
self.object_store.download_file(
remote_file_name=object_name,
destination=file_destination,
)
log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
# Download the shard file to the relative path it's associated to and save that relative path
# to the root directory specified to the FileSystem reader constructor.
file_destination = str(Path(self.destination_path) / Path(relative_file_path))

# The file could have already been downloaded as different plan items can point to same file.
if not is_downloaded and not os.path.exists(file_destination):
log.debug(f'Downloading {relative_file_path} to {file_destination}.')
object_name = str(Path(self.source_path) / Path(relative_file_path))
if isinstance(self.object_store, ObjectStore):
self.object_store.download_object(
object_name=object_name,
filename=file_destination,
)
else:
self.object_store.download_file(
remote_file_name=object_name,
destination=file_destination,
)
log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
except Exception as e:
# PyTorch will capture any exception of this function,
# and dist.all_gather_objects(exception) before raising it.
# If that all_gather_objects fails, the exception is never visible to user.
# We immediately print the exception to avoid that situation.
log.error(f'Exception {type(e)} raised during downloading: {str(e)}')
raise e

# 3. Wait for all ranks to finish.
log.debug(f'Rank {dist.get_global_rank()} finished downloading all files.')
Expand Down

0 comments on commit cf031e2

Please sign in to comment.