Skip to content

Commit

Permalink
suppress errors in torchsnapshot
Browse files Browse the repository at this point in the history
Differential Revision: D56154094

fbshipit-source-id: 875fdff3a7d3f30493936c2a04fbed98700f93ac
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Apr 15, 2024
1 parent 8bbf219 commit 706ee13
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 0 deletions.
1 change: 1 addition & 0 deletions tests/gpu_tests/test_dtensor_io_preparer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def test_dtensor_io_preparer(
"""
Verify the basic behavior of DTensorIOPreparer prepare_write.
"""
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.Any],...
device_mesh = DeviceMesh("cuda", mesh=mesh)

if len(placements) > device_mesh.ndim:
Expand Down
1 change: 1 addition & 0 deletions tests/gpu_tests/test_dtensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TestDTensorUtils(DTensorTestBase):
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[3]: Return type must be annotated.
def test_is_sharded_is_replicated(self):
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.Any],...
mesh = DeviceMesh("cuda", mesh=[[0, 1], [2, 3]])
placements = [Replicate(), Shard(0)]
local_tensor = torch.rand((16, 16))
Expand Down
2 changes: 2 additions & 0 deletions tests/test_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,10 @@ def dtensor_test_cases(
dsts = []
for idx in range(NUM_TENSORS):
mesh = (
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.A...
DeviceMesh("cuda", mesh=[0])
if use_gpu and idx % 2 == 0
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.A...
else DeviceMesh("cpu", mesh=[0])
)
srcs.append(
Expand Down
1 change: 1 addition & 0 deletions tests/test_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _create_dtensor() -> DTensor:

local_tensor = torch.rand((dim_0, dim_1))

# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.Any],...
mesh = DeviceMesh("cpu", mesh=[[0, 1], [2, 3]])
placements = [Replicate(), Shard(0)]
dtensor = distribute_tensor(
Expand Down
4 changes: 4 additions & 0 deletions torchsnapshot/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ def _dtensor_test_case(
replicated: bool,
) -> Tuple[DTensor, Entry, List[WriteReq]]:
# WORLD_SIZE needs to be at least 4
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[typing.Any],
# _NestedSequence[Union[bool, bytes, complex, float, int, str]],
# _NestedSequence[_SupportsArray[typing.Any]], bool, bytes, complex, float, int,
# str, Tensor]` but got `List[List[int]]`.
mesh = DeviceMesh("cuda", mesh=[[0, 1], [2, 3]])
placements = [Replicate(), Shard(0)]
local_tensor = rand_tensor(shape, dtype=dtype)
Expand Down

0 comments on commit 706ee13

Please sign in to comment.