diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 69ff2e9b..c1ed97c5 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -17,9 +17,11 @@ from brax.envs.base import PipelineEnv from brax.io import image -import gym -from gym import spaces -from gym.vector import utils + +import gymnasium as gym +from gymnasium import spaces +from gymnasium.vector import utils + import jax import numpy as np @@ -72,7 +74,7 @@ def reset(self): def step(self, action): self._state, obs, reward, done, info = self._step(self._state, action) # We return device arrays for pytorch users. - return obs, reward, done, info + return gym.utils.step_api_compatibility.convert_to_terminated_truncated_step_api((obs, reward, done, info)) def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) @@ -129,7 +131,7 @@ def reset(key): def step(state, action): state = self._env.step(state, action) info = {**state.metrics, **state.info} - return state, state.obs, state.reward, state.done, info + return gym.utils.step_api_compatibility.convert_to_terminated_truncated_step_api((state, state.obs, state.reward, state.done, info), is_vector_env=True) self._step = jax.jit(step, backend=self.backend) diff --git a/brax/v1/envs/wrappers.py b/brax/v1/envs/wrappers.py index 6326c103..dffb065d 100644 --- a/brax/v1/envs/wrappers.py +++ b/brax/v1/envs/wrappers.py @@ -19,12 +19,11 @@ from brax.v1 import jumpy as jp from brax.v1.envs import env as brax_env -import dm_env -from dm_env import specs -import flax + import gym from gym import spaces from gym.vector import utils + import jax import jax.numpy as jnp diff --git a/setup.py b/setup.py index 13bd2c16..520523a7 100644 --- a/setup.py +++ b/setup.py @@ -45,8 +45,8 @@ "flask_cors", "flax", # TODO: remove grpcio and gym after dropping legacy v1 code + "gymnasium", "grpcio", - "gym", "jax>=0.4.6", "jaxlib>=0.4.6", "jaxopt",