diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 022fb36b346f4..de79c3b945d4d 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -39,7 +39,8 @@ def test_filter_subtensors(): filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): - assert tensor.equal(state_dict[key]) + # NOTE: don't use `euqal` here, as the tensor might contain NaNs + assert tensor is state_dict[key] @pytest.fixture(scope="module")