diff --git a/tests/test_manifest.py b/tests/test_manifest.py
index 5c8ccb7..b1093a8 100644
--- a/tests/test_manifest.py
+++ b/tests/test_manifest.py
@@ -725,7 +725,7 @@ def test_replicated_entries_only_on_rank_0(rank: int) -> None:
 
 
 def _update_local_manifest_with_merged_entries(
-    local_manifest: Dict[str, Entry]
+    local_manifest: Dict[str, Entry],
 ) -> None:
     """
     Update the expected local manifest with manually merged ShardedTensorEntries
diff --git a/torchsnapshot/asyncio_utils.py b/torchsnapshot/asyncio_utils.py
index 0352b59..401b957 100644
--- a/torchsnapshot/asyncio_utils.py
+++ b/torchsnapshot/asyncio_utils.py
@@ -47,7 +47,9 @@ def _run_once(self):
     timeout = (
         0
         if ready or self._stopping
-        else min(max(scheduled[0]._when - now, 0), 86400) if scheduled else None
+        else min(max(scheduled[0]._when - now, 0), 86400)
+        if scheduled
+        else None
     )
     event_list = self._selector.select(timeout)
     self._process_events(event_list)
diff --git a/torchsnapshot/io_preparers/tensor.py b/torchsnapshot/io_preparers/tensor.py
index 671df32..281ea62 100644
--- a/torchsnapshot/io_preparers/tensor.py
+++ b/torchsnapshot/io_preparers/tensor.py
@@ -199,7 +199,7 @@ def can_load_inplace(
 
     @staticmethod
     def empty_tensor_from_entry(
-        entry: Union[TensorEntry, ChunkedTensorEntry]
+        entry: Union[TensorEntry, ChunkedTensorEntry],
    ) -> torch.Tensor:
         if entry.dtype in SUPPORTED_QUANTIZED_DTYPES:
             # TODO: we can't allocate empty quantized tensors because we don't
@@ -394,11 +394,15 @@ def tensor_copy(dst: torch.Tensor, src: torch.Tensor) -> None:
 
     # a region of the larger tensor's storage contain data that does not match
     # the larger tensor's qscheme.
-    if src.is_quantized and (
-        not dst.is_quantized  # Copying from quantized Tensor to non-quantized Tensor is not allowed
-        or dst.qscheme() != src.qscheme()  # Quantized copy only works with same qscheme
-        or dst.dtype != src.dtype  # Quantized copy requires matching dtypes
-        or (dst._is_view() and not _q_params_equal(dst, src))  # See the top comment
+    if (
+        src.is_quantized
+        and (
+            not dst.is_quantized  # Copying from quantized Tensor to non-quantized Tensor is not allowed
+            or dst.qscheme()
+            != src.qscheme()  # Quantized copy only works with same qscheme
+            or dst.dtype != src.dtype  # Quantized copy requires matching dtypes
+            or (dst._is_view() and not _q_params_equal(dst, src))  # See the top comment
+        )
     ):
         # TODO: tile the dequantize -> copy to reduce memory footprint
         src = _tensor_dequantize(src)
diff --git a/torchsnapshot/manifest_ops.py b/torchsnapshot/manifest_ops.py
index 565fb38..45a81d2 100644
--- a/torchsnapshot/manifest_ops.py
+++ b/torchsnapshot/manifest_ops.py
@@ -109,7 +109,7 @@ def _get_rank_to_manifest(metadata: SnapshotMetadata) -> List[Dict[str, Entry]]:
 
 
 def _get_merged_sharded_tensor_entries(
-    rank_to_manifest: List[Dict[str, Entry]]
+    rank_to_manifest: List[Dict[str, Entry]],
 ) -> Dict[str, Entry]:
     groups = defaultdict(list)
     for manifest in rank_to_manifest:
@@ -130,7 +130,7 @@ def _get_merged_sharded_tensor_entries(
 
 
 def _get_merged_dtensor_entries(
-    rank_to_manifest: List[Dict[str, Entry]]
+    rank_to_manifest: List[Dict[str, Entry]],
 ) -> Dict[str, Entry]:
     """
     Merge all DTensor entries across ranks if sharded
diff --git a/torchsnapshot/partitioner.py b/torchsnapshot/partitioner.py
index 7f8efed..63ba5f0 100644
--- a/torchsnapshot/partitioner.py
+++ b/torchsnapshot/partitioner.py
@@ -283,7 +283,7 @@ def partition_write_reqs(
 
 
 def _consolidate_replicated_chunked_tensor_entries(
-    rank_to_entries: List[Dict[str, Entry]]
+    rank_to_entries: List[Dict[str, Entry]],
 ) -> List[Dict[str, Entry]]:
     groups: Dict[str, List[ChunkedTensorEntry]] = defaultdict(list)
 
diff --git a/torchsnapshot/serialization.py b/torchsnapshot/serialization.py
index f991466..4202e47 100644
--- a/torchsnapshot/serialization.py
+++ b/torchsnapshot/serialization.py
@@ -245,8 +245,7 @@ def contiguous_view_as_untyped_storage(tensor: torch.Tensor) -> UntypedStorage:
     else:
         untyped_storage = tensor.storage().untyped()
     return untyped_storage[
-        tensor.storage_offset()
-        * tensor.element_size() : tensor.storage_offset()
+        tensor.storage_offset() * tensor.element_size() : tensor.storage_offset()
         * tensor.element_size()
         + tensor.nelement() * tensor.element_size()
     ]
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index d0c0f40..059168a 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -863,7 +863,6 @@ def _coalesce_path_and_replicated(
     app_state: AppState,
     replicated: List[str],
 ) -> Tuple[str, Set[str]]:
-    rank = pg_wrapper.get_rank()
 
     # coalesce path