diff --git a/tests/callbacks/test_oom_observer.py b/tests/callbacks/test_oom_observer.py index fb5d217f55..e81fb3520a 100644 --- a/tests/callbacks/test_oom_observer.py +++ b/tests/callbacks/test_oom_observer.py @@ -17,15 +17,11 @@ @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), reason='OOM Observer requires PyTorch 2.1 or higher') -def test_oom_observer_warnings_on_cpu_models(device: str): - - # Error if the user sets device=cpu even when cuda is available - del device # unused. always using cpu +def test_oom_observer_warnings_on_cpu_models(): ob = OOMObserver() Trainer( model=SimpleModel(), callbacks=ob, - device='cpu', train_dataloader=DataLoader(RandomClassificationDataset()), max_duration='1ba', )