Add torch 2.3 CI/CD (mosaicml#3211)
* add tests

* fix monkeypatch

* remove api fallback

* fix tests

* fix
mvpatel2000 authored Apr 25, 2024
1 parent 8002624 commit dcf8828
Showing 8 changed files with 40 additions and 23 deletions.
28 changes: 19 additions & 9 deletions .github/workflows/daily.yaml
@@ -23,13 +23,18 @@ jobs:
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.10-2.1-composer
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
- name: cpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.11-2.2-composer
container: mosaicml/pytorch:2.2.1_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: cpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
- name: cpu-3.11-2.3
container: mosaicml/pytorch:2.3.0_cpu-python3.11-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
@@ -43,16 +43,21 @@ jobs:
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.1-composer
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: daily-cpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-3.11-2.2-composer
container: mosaicml/pytorch:2.2.1_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: daily-cpu-3.11-2.3-composer
container: mosaicml/pytorch:2.3.0_cpu-python3.11-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: daily-cpu-doctest
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and doctest
4 changes: 4 additions & 0 deletions .github/workflows/pr-cpu.yaml
@@ -21,6 +21,10 @@ jobs:
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-3.11-2.3
container: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
- name: cpu-doctest
container: mosaicml/pytorch:2.2.1_cpu-python3.11-ubuntu20.04
markers: not daily and not remote and not gpu and doctest
4 changes: 2 additions & 2 deletions .github/workflows/pr-gpu.yaml
@@ -13,8 +13,8 @@ jobs:
strategy:
matrix:
include:
- name: gpu-3.11-2.2
container: mosaicml/pytorch:2.2.1_cu121-python3.10-ubuntu20.04
- name: gpu-3.11-2.3
container: mosaicml/pytorch:2.3.0_cu121-python3.11-ubuntu20.04
markers: not daily and not remote and gpu and (doctest or not doctest)
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
14 changes: 7 additions & 7 deletions composer/core/state.py
@@ -873,7 +873,7 @@ def get_model_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the model.
"""
if version.parse(torch.__version__) >= version.parse('2.3.0'):
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -899,9 +899,9 @@ def get_model_state_dict(self) -> Dict[str, Any]:
else:
model_state_dict = self.model.state_dict()

# If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
if self.is_model_ddp:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')
# If model is DDP wrapped, do not save the `module.` prefix, as that is an implementation detail
if self.is_model_ddp:
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.')

return model_state_dict

@@ -911,7 +911,7 @@ def get_optim_state_dict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The state dict for the optimizer.
"""
if version.parse(torch.__version__) >= version.parse('2.3.0'):
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict
if self.fsdp_state_dict_type not in [None, 'full', 'sharded']:
raise NotImplementedError(
@@ -1233,7 +1233,7 @@ def load_model_state(
model_on_rank = state_dict['model'] is not None

if model_on_rank:
if version.parse(torch.__version__) >= version.parse('2.3.0'):
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
set_model_state_dict(
model=self.model,
@@ -1297,7 +1297,7 @@ def load_optim_state(self, state_dict: Dict[str, Any], strict: bool = True):
strict (bool): Whether the keys (i.e., optimizer parameter names) in the optimizer
state dict should perfectly match the keys in the optimizer instance.
"""
if version.parse(torch.__version__) >= version.parse('2.3.0'):
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_optimizer_state_dict
optimizer = self.optimizers[0]
set_optimizer_state_dict(
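The `composer/core/state.py` hunks above all follow the same pattern: use the new `torch.distributed.checkpoint.state_dict` helpers only when torch is at least 2.3.0 and a process group is initialized, and otherwise fall back to the plain `state_dict()`/`load_state_dict()` path. A minimal sketch of that gate is below; the standalone helper functions and the `full_state_dict`/`cpu_offload` option values are illustrative assumptions, not the exact Composer call sites.

```python
# Hedged sketch of the torch >= 2.3 state-dict gate used in the hunks above.
# `gather_model_state_dict` / `restore_model_state_dict` are hypothetical helpers;
# the StateDictOptions values are illustrative, not Composer's exact settings.
from typing import Any, Dict

import torch
from packaging import version

from composer.utils import dist


def gather_model_state_dict(model: torch.nn.Module) -> Dict[str, Any]:
    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
        from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

        # full_state_dict=True gathers FSDP shards; cpu_offload=True keeps the result off GPU.
        return get_model_state_dict(
            model,
            options=StateDictOptions(full_state_dict=True, cpu_offload=True),
        )
    # Older torch, or a single-process run: the plain state dict is sufficient.
    return model.state_dict()


def restore_model_state_dict(model: torch.nn.Module, state: Dict[str, Any], strict: bool = True) -> None:
    if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
        from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

        set_model_state_dict(
            model,
            state,
            options=StateDictOptions(full_state_dict=True, strict=strict),
        )
    else:
        model.load_state_dict(state, strict=strict)
```

The same shape applies to the optimizer paths above, with `get_optimizer_state_dict`/`set_optimizer_state_dict` in place of the model variants.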
2 changes: 2 additions & 0 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -556,6 +556,8 @@ def _shard_orig_param_state(
if version.parse(torch.__version__) >= version.parse('2.3.0') and version.parse(
torch.__version__,
) < version.parse('2.3.1'):
from torch.distributed._tensor import DTensor

@no_type_check
def _same_storage(a, b):
if isinstance(a, DTensor):
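For context on why the `DTensor` import is pulled into this version-gated block: the torch 2.3.0 monkeypatch above compares tensor storages while sharding optimizer state, and it has to unwrap `DTensor` shards before doing so. A rough sketch of such a check follows; it illustrates the idea rather than reproducing the exact patched body.

```python
# Hedged sketch: a storage-identity check that tolerates DTensor inputs.
import torch
from torch.distributed._tensor import DTensor


def _same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Unwrap DTensors to their local shards before comparing storage pointers.
    if isinstance(a, DTensor):
        a = a.to_local()
    if isinstance(b, DTensor):
        b = b.to_local()
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
```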
3 changes: 2 additions & 1 deletion tests/algorithms/test_required_on_load.py
@@ -15,6 +15,7 @@
from composer.callbacks import CheckpointSaver
from composer.core import Algorithm, Event, Time, TimeUnit # type: ignore imports used in `eval(representation)`
from composer.models import ComposerClassifier, ComposerModel
from composer.utils import dist
from tests.common import ConvModel, SimpleConvModel, composer_resnet


@@ -173,7 +174,7 @@ def test_autoload(
context = pytest.warns(UserWarning, match='Automatically adding required_on_load algorithm*')
# Excluding some algorithms leads to errors when loading
elif exclude:
if version.parse(torch.__version__) > version.parse('2.2.9'):
if version.parse(torch.__version__) >= version.parse('2.3.0') and dist.is_initialized():
if algo_name in [
'Alibi',
'BlurPool',
4 changes: 2 additions & 2 deletions tests/trainer/test_checkpoint.py
@@ -877,7 +877,7 @@ def test_strict_errors(self, missing_key: bool, unexpected_key: bool):
last_checkpoint = os.path.join('first', 'ep2.pt')
if missing_key or unexpected_key:
message = r'Error\(s\) in loading state_dict'
if version.parse(torch.__version__) < version.parse('2.2.9'):
if version.parse(torch.__version__) < version.parse('2.2.3') or not dist.is_initialized():
# Composer implements strict for older torch versions
message = 'Failed to load checkpoint due to'
error_context = pytest.raises(RuntimeError, match=message)
@@ -1229,7 +1229,7 @@ def test_autoload_algorithm_old_checkpoint(self):
NoOpModel.__init__ = lambda self, x: None # type: ignore
NoOpModel.__repr__ = lambda self: 'NoOpModel(3)'
error_context = pytest.raises(KeyError, match='module.0.weight')
if version.parse(torch.__version__) < version.parse('2.2.9'):
if version.parse(torch.__version__) < version.parse('2.2.3') or not dist.is_initialized():
error_context = pytest.raises(ValueError, match='loaded state dict contains a parameter group.*')
with pytest.warns(UserWarning, match='required_on_load algorithm.*'), error_context:
trainer_3 = self.get_trainer(load_path=os.path.join('first', 'ep1.pt'))
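The two `tests/trainer/test_checkpoint.py` hunks above switch the expected failure mode on the same torch-version/distributed gate. A condensed sketch of that selection pattern is below; the `expected_strict_error` helper is hypothetical, pulled out only to show the branching.

```python
# Hedged sketch of the version-dependent error expectation used in the tests above.
import pytest
import torch
from packaging import version

from composer.utils import dist


def expected_strict_error():
    # On newer torch with an initialized process group, torch's own strict-loading
    # error surfaces; otherwise Composer raises its "Failed to load checkpoint" error.
    message = r'Error\(s\) in loading state_dict'
    if version.parse(torch.__version__) < version.parse('2.2.3') or not dist.is_initialized():
        message = 'Failed to load checkpoint due to'
    return pytest.raises(RuntimeError, match=message)
```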
4 changes: 2 additions & 2 deletions tests/trainer/test_fsdp_checkpoint.py
@@ -525,7 +525,7 @@ def test_fsdp_load_old_checkpoint(
'state': trainer2.state.state_dict(),
'rng': get_rng_state(),
}
if version.parse(torch.__version__) < version.parse('2.2.9'):
if version.parse(torch.__version__) < version.parse('2.2.3'):
state_dict['state'].pop('optimizers')

object_store = S3ObjectStore(bucket=f'{s3_bucket}')
@@ -543,7 +543,7 @@
planner=None,
process_group=process_group,
)
if version.parse(torch.__version__) < version.parse('2.2.9'):
if version.parse(torch.__version__) < version.parse('2.2.3'):
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model_state_dict = state_dict['state']['model']
