diff --git a/composer/callbacks/oom_observer.py b/composer/callbacks/oom_observer.py index f16248d9d5..f8f54dec9c 100644 --- a/composer/callbacks/oom_observer.py +++ b/composer/callbacks/oom_observer.py @@ -86,7 +86,7 @@ def __init__( else: self.remote_path_in_bucket = None - if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore + if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.1.0'): # type: ignore # OOMObserver is only supported in torch v2.1.0-rc1 or higher self._enabled = True else: