diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 92ba8ffccc2..bb0d2005a13 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -234,6 +234,7 @@ def update(batch, num_network_updates): # Compute GAE with torch.no_grad(), timeit("adv"): + torch.compiler.cudagraph_mark_step_begin() data = adv_module(data) with timeit("rb - extend"): # Update the data buffer diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index bd149d60f5c..8615fb47084 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -223,7 +223,8 @@ def update(batch, num_network_updates): # Compute GAE with torch.no_grad(), timeit("adv"): - data = adv_module(data.to(device)) + torch.compiler.cudagraph_mark_step_begin() + data = adv_module(data) with timeit("rb - extend"): # Update the data buffer data_reshape = data.reshape(-1)