From 04c0f5fbda9c0eb3360732ac0c9c7b544ebef694 Mon Sep 17 00:00:00 2001 From: NripeshN Date: Wed, 8 Nov 2023 01:55:54 +0400 Subject: [PATCH 1/2] Add device selection logic based on availability --- .../rastervision/pytorch_learner/learner.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index bfd06b724..a7f561115 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -144,8 +144,15 @@ def __init__(self, self._tmp_dir = get_tmp_dir() tmp_dir = self._tmp_dir.name self.tmp_dir = tmp_dir - self.device = torch.device('cuda' - if torch.cuda.is_available() else 'cpu') + + if torch.backends.mps.is_available(): + DEFAULT_DEVICE = "mps" + elif torch.cuda.is_available(): + DEFAULT_DEVICE = "cuda" + else: + DEFAULT_DEVICE = "cpu" + + self.device = torch.device(DEFAULT_DEVICE) self.train_ds = train_ds self.valid_ds = valid_ds From 14191e618a6bd984efa8701e64a02ee8dbaa73e4 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Mon, 20 Nov 2023 13:37:44 -0500 Subject: [PATCH 2/2] fix formatting and slightly refactor --- .../rastervision/pytorch_learner/learner.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index a7f561115..108564966 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -144,15 +144,15 @@ def __init__(self, self._tmp_dir = get_tmp_dir() tmp_dir = self._tmp_dir.name self.tmp_dir = tmp_dir - - if torch.backends.mps.is_available(): - DEFAULT_DEVICE = "mps" - elif torch.cuda.is_available(): - DEFAULT_DEVICE = "cuda" + + if torch.cuda.is_available(): + device = 'cuda' + elif torch.backends.mps.is_available(): + device = 'mps' else: - DEFAULT_DEVICE = "cpu" - - self.device = torch.device(DEFAULT_DEVICE) + device = 'cpu' + + self.device = torch.device(device) self.train_ds = train_ds self.valid_ds = valid_ds