diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 2e5fd5d07b..3b4f26024c 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -316,6 +316,8 @@ def test_fsdp_full_state_dict_load( ): if use_hsdp: pytest.xfail('Known Pytorch issue with HSDP, waiting for pytorch patch') + if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'): + pytest.skip('HSDP and TP require torch 2.3.0 or later') if autoresume: run_name = 'my-cool-autoresume-run' else: @@ -833,8 +835,8 @@ def test_fsdp_partitioned_state_dict_load( ): if weights_only and autoresume: pytest.skip('Weights only with autoresume is not supported') - if use_tp and version.parse(torch.__version__) < version.parse('2.3.0'): - pytest.skip('TP requires torch 2.3.0 or later') + if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'): + pytest.skip('HSDP and TP require torch 2.3.0 or later') load_ignore_keys = [] if load_ignore_keys is None else load_ignore_keys