Skip to content

Commit

Permalink
Remove unnecessary state_dict and load_state_dict functions (mosaicml…
Browse files Browse the repository at this point in the history
  • Loading branch information
eracah authored Jun 6, 2024
1 parent 7d7f888 commit 26084cd
Show file tree
Hide file tree
Showing 8 changed files with 7 additions and 77 deletions.
4 changes: 1 addition & 3 deletions composer/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@
import torch.nn
from torch.optim import Optimizer

from composer.core.serializable import Serializable

__all__ = ['Device', 'T_nnModule']

T_nnModule = TypeVar('T_nnModule', bound=torch.nn.Module)
T_Batch = TypeVar('T_Batch')


class Device(Serializable, ABC):
class Device(ABC):
"""Abstract class for a device on which a model runs.
Attributes:
Expand Down
10 changes: 1 addition & 9 deletions composer/devices/device_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging
from typing import Any, TypeVar
from typing import TypeVar

import torch

Expand Down Expand Up @@ -34,11 +34,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> dict[str, Any]:
# CPU device has no RNG state
return {}

def load_state_dict(self, state: dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('CPU device has no state.')
10 changes: 1 addition & 9 deletions composer/devices/device_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Any, Optional, TypeVar
from typing import Optional, TypeVar

import torch
import torch.backends.cuda
Expand Down Expand Up @@ -57,11 +57,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device, non_blocking=True)

def state_dict(self) -> dict[str, Any]:
return {
'rng': torch.cuda.get_rng_state(),
}

def load_state_dict(self, state: dict[str, Any]) -> None:
torch.cuda.set_rng_state(state['rng'])
10 changes: 1 addition & 9 deletions composer/devices/device_hpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging
from typing import Any, TypeVar
from typing import TypeVar

import torch

Expand Down Expand Up @@ -34,11 +34,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> dict[str, Any]:
# HPU device has no RNG state
return {}

def load_state_dict(self, state: dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('HPU device has no state.')
9 changes: 1 addition & 8 deletions composer/devices/device_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Any, TypeVar
from typing import TypeVar

import torch
import torch.cuda.amp
Expand Down Expand Up @@ -42,10 +42,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> dict[str, Any]:
return {}

def load_state_dict(self, state: dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('MPS device has no state.')
9 changes: 1 addition & 8 deletions composer/devices/device_neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
import os
from typing import Any, TypeVar
from typing import TypeVar

import torch

Expand Down Expand Up @@ -43,10 +43,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> dict[str, Any]:
return {}

def load_state_dict(self, state: dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('Neuron device has no state.')
9 changes: 1 addition & 8 deletions composer/devices/device_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import logging
from typing import Any, TypeVar
from typing import TypeVar

import torch

Expand Down Expand Up @@ -40,10 +40,3 @@ def module_to_device(self, module: T_nnModule) -> T_nnModule:

def tensor_to_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self._device)

def state_dict(self) -> dict[str, Any]:
return {}

def load_state_dict(self, state: dict[str, Any]) -> None:
if len(state) != 0:
raise ValueError('TPU device has no state.')
23 changes: 0 additions & 23 deletions composer/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,29 +184,6 @@ def log_images(
else:
wandb.log({name: list(wandb_images)}, step=step)

def state_dict(self) -> dict[str, Any]:
import wandb

# Storing these fields in the state dict to support run resuming in the future.
if self._enabled:
if wandb.run is None:
raise ValueError('wandb module must be initialized before serialization.')

# If WandB is disabled, most things are RunDisabled objects, which are not
# pickleable due to overriding __getstate__ but not __setstate__
if wandb.run.disabled:
return {}
else:
return {
'name': wandb.run.name,
'project': wandb.run.project,
'entity': wandb.run.entity,
'id': wandb.run.id,
'group': wandb.run.group,
}
else:
return {}

def init(self, state: State, logger: Logger) -> None:
import wandb
del logger # unused
Expand Down

0 comments on commit 26084cd

Please sign in to comment.