Skip to content

Commit

Permalink
Attempt single blob load
Browse files: browse the repository at this point in the history
  • Loading branch information
daviswer committed Jan 14, 2025
1 parent d2eb12e commit ada91ec
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torchdata/stateful_dataloader/ibm_rescalable.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -587,23 +587,24 @@ def load_distributed_state_dict(
nworkers = base["_snapshot"]["_main_snapshot"]["_num_workers"]
rank = loader.dataset.rank
dstate = __pop_dstate(base, device_mesh, [dtensor.placement_types.Shard(0)]) # placements)
inp = {"state":base, "dstate":dstate}
# Read nondistributed state dict
ckp_ws = 0 if not os.path.exists(path) else len([x for x in os.listdir(path) if "__nondist_cp_" in x])
ckp_ws = 0 if not os.path.exists(path) else len(os.listdir(path))
# Read distributed state dict
reader = checkpoint.FileSystemReader(path)
checkpoint.load_state_dict(
inp,
reader,
)
dstate = inp["dstate"]
# Check that number of loaders matches
if ckp_ws == loader.dataset.worldsize:
state = torch.load(os.path.join(path, f"__nondist_cp_{rank}.pth"))
# Check that number of workers matches
if nworkers != state["_snapshot"]["_main_snapshot"]["_num_workers"]:
state = base
state = inp["state"]
else:
# On mismatch, discard saved non-reshardable loader state and start fresh
state = base
# Read distributed state dict
reader = checkpoint.FileSystemReader(path)
checkpoint.load_state_dict(
dstate,
reader,
)
# Get local tensors from dtensors, and slice over workers
dstate = {k: v.to_local().chunk(nworkers) for k, v in dstate.items()}
# Flip dict[list[tensor]] to list[dict[tensor]]
Expand Down

0 comments on commit ada91ec

Please sign in to comment.