Skip to content

Commit

Permalink
add test with snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
cli99 committed Feb 2, 2024
1 parent 5a74d34 commit ba6c859
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions tests/callbacks/test_oom_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch.utils.data import DataLoader

from composer import State, Trainer
from composer.callbacks import OOMObserver
from composer.callbacks import MemorySnapshot, OOMObserver
from composer.loggers import LoggerDestination
from composer.trainer import Trainer
from tests.common import RandomClassificationDataset, SimpleModel, device
Expand All @@ -31,6 +31,7 @@ def test_oom_observer_warnings_on_cpu_models(device: str):
max_duration='1ba',
)


class FileUploaderTracker(LoggerDestination):

def __init__(self) -> None:
Expand All @@ -45,8 +46,6 @@ def upload_file(self, state: State, remote_file_name: str, file_path: pathlib.Pa
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer():
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
pytest.skip('oom_observer is supported after PyTorch 2.1.0.')

# Construct the callbacks
oom_observer = OOMObserver()
Expand All @@ -61,6 +60,7 @@ def test_oom_observer():
loggers=file_tracker_destination,
callbacks=oom_observer,
train_dataloader=DataLoader(RandomClassificationDataset()),
max_duration='2ba',
)

# trigger OOM
Expand All @@ -69,3 +69,28 @@ def test_oom_observer():
trainer.fit()

assert len(file_tracker_destination.uploaded_files) == 5


@pytest.mark.gpu
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'),
                    reason='OOM Observer requires PyTorch 2.1 or higher')
def test_oom_observer_with_memory_snapshot():
    """Fit for two batches with OOMObserver and MemorySnapshot attached together.

    The test passes when the tracking logger destination's ``uploaded_files``
    holds exactly one entry after ``trainer.fit()`` returns.
    """
    # Both callbacks under test, built in the order they are handed to the trainer.
    callbacks_under_test = [
        OOMObserver(),
        MemorySnapshot(skip_batches=0, interval='1ba'),
    ]

    model = SimpleModel()

    # Logger destination whose ``uploaded_files`` is inspected once the fit is done.
    upload_tracker = FileUploaderTracker()

    trainer = Trainer(
        model=model,
        loggers=upload_tracker,
        callbacks=callbacks_under_test,
        train_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration='2ba',
    )
    trainer.fit()

    uploaded_count = len(upload_tracker.uploaded_files)
    assert uploaded_count == 1

0 comments on commit ba6c859

Please sign in to comment.