Skip to content

Commit

Permalink
[Refactor] Use default device instead of CPU in losses
Browse files Browse the repository at this point in the history
ghstack-source-id: 8b98062c3ae88d8780ef7428fdfa07e305c790b9
Pull Request resolved: #2687
vmoens committed Jan 16, 2025
1 parent 256a700 commit c3b9d1d
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
min_alpha = min_alpha if min_alpha else 0.0
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
min_alpha = min_alpha if min_alpha else 0.0
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/decision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

self.register_buffer("alpha_init", torch.as_tensor(alpha_init, device=device))
self.register_buffer(
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(
try:
device = next(self.parameters()).device
except (AttributeError, StopIteration):
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
if critic_coef is not None:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
self.register_buffer(
Expand Down
4 changes: 2 additions & 2 deletions torchrl/objectives/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()
self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
min_alpha = min_alpha if min_alpha else 0.0
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def __init__(
try:
device = next(self.parameters()).device
except AttributeError:
device = torch.device("cpu")
device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))()

self.register_buffer("alpha_init", torch.tensor(alpha_init, device=device))
if bool(min_alpha) ^ bool(max_alpha):
Expand Down

0 comments on commit c3b9d1d

Please sign in to comment.