diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index 334a486e7e2..3c73c8ca8f8 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -112,8 +112,12 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create optimizers - actor_optim = torch.optim.Adam(actor.parameters(), lr=cfg.optim.lr, eps=1e-5) - critic_optim = torch.optim.Adam(critic.parameters(), lr=cfg.optim.lr, eps=1e-5) + actor_optim = torch.optim.Adam( + actor.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5 + ) + critic_optim = torch.optim.Adam( + critic.parameters(), lr=torch.tensor(cfg.optim.lr, device=device), eps=1e-5 + ) optim = group_optimizers(actor_optim, critic_optim) del actor_optim, critic_optim