diff --git a/conftest.py b/conftest.py index 7f6129515..6cbb8a09e 100644 --- a/conftest.py +++ b/conftest.py @@ -38,6 +38,12 @@ def pytest_addoption(parser): action="store_true", help="Run benchmarks on mps only and ignore machine configuration checks", ) + parser.addoption( + "--device_only", + action="store", + default=None, + help="Run benchmarks on the specific device only and ignore machine configuration checks", + ) def set_fuser(fuser): diff --git a/test_bench.py b/test_bench.py index 724b05d20..72db81eb5 100644 --- a/test_bench.py +++ b/test_bench.py @@ -31,6 +31,9 @@ def pytest_generate_tests(metafunc): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): devices.append("mps") + if device_only := metafunc.config.option.device_only: + devices = [device_only] + if metafunc.config.option.cpu_only: devices = ["cpu"]