diff --git a/brax/actuator_test.py b/brax/actuator_test.py index 1b474501..34a386e3 100644 --- a/brax/actuator_test.py +++ b/brax/actuator_test.py @@ -32,6 +32,7 @@ def _actuator_step(pipeline, sys, q, qd, act, dt, n): sys = sys.tree_replace({'opt.timestep': dt}) + def f(state, _): return jax.jit(pipeline.step)(sys, state, act), None diff --git a/brax/base.py b/brax/base.py index e97f20e5..726a15e9 100644 --- a/brax/base.py +++ b/brax/base.py @@ -284,8 +284,8 @@ class Inertia(Base): """Angular inertia, mass, and center of mass location. Attributes: - transform: transform for the inertial frame relative to the link frame - (i.e. center of mass position and orientation) + transform: transform for the inertial frame relative to the link frame (i.e. + center of mass position and orientation) i: (3, 3) inertia matrix about a point P mass: scalar mass """ diff --git a/brax/com.py b/brax/com.py index db26b95d..cb726c01 100644 --- a/brax/com.py +++ b/brax/com.py @@ -13,6 +13,7 @@ # limitations under the License. """Helper functions for physics calculations in maximal coordinates.""" + # pylint:disable=g-multiple-import from typing import Tuple diff --git a/brax/com_test.py b/brax/com_test.py index 90195339..8168d8af 100644 --- a/brax/com_test.py +++ b/brax/com_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for com.""" + # pylint:disable=g-multiple-import from absl.testing import absltest from brax import com diff --git a/brax/envs/half_cheetah.py b/brax/envs/half_cheetah.py index 5b89bc66..2f8cbe15 100644 --- a/brax/envs/half_cheetah.py +++ b/brax/envs/half_cheetah.py @@ -178,7 +178,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state - assert pipeline_state0 is not None + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/envs/hopper.py b/brax/envs/hopper.py index 516a6336..202e3993 100644 --- a/brax/envs/hopper.py +++ b/brax/envs/hopper.py @@ -223,7 +223,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state - assert pipeline_state0 is not None + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/envs/humanoid.py b/brax/envs/humanoid.py index d593e598..8e5a17d1 100644 --- a/brax/envs/humanoid.py +++ b/brax/envs/humanoid.py @@ -324,7 +324,8 @@ def _get_obs( com_velocity = jp.hstack([com_vel, com_ang]) qfrc_actuator = actuator.to_tau( - self.sys, action, pipeline_state.q, pipeline_state.qd) + self.sys, action, pipeline_state.q, pipeline_state.qd + ) # external_contact_forces are excluded return jp.concatenate([ diff --git a/brax/envs/humanoidstandup.py b/brax/envs/humanoidstandup.py index c71139fe..024a9394 100644 --- a/brax/envs/humanoidstandup.py +++ b/brax/envs/humanoidstandup.py @@ -261,7 +261,8 @@ def _get_obs( com_velocity = jp.hstack([com_vel, com_ang]) qfrc_actuator = actuator.to_tau( - self.sys, action, pipeline_state.q, pipeline_state.qd) + self.sys, action, pipeline_state.q, pipeline_state.qd + ) # external_contact_forces are excluded return jp.concatenate([ diff --git a/brax/envs/inverted_double_pendulum.py b/brax/envs/inverted_double_pendulum.py index 8916fd21..b7dd2b60 100644 --- a/brax/envs/inverted_double_pendulum.py +++ 
b/brax/envs/inverted_double_pendulum.py @@ -118,8 +118,7 @@ class InvertedDoublePendulum(PipelineEnv): def __init__(self, backend='generalized', **kwargs): path = ( - epath.resource_path('brax') - / 'envs/assets/inverted_double_pendulum.xml' + epath.resource_path('brax') / 'envs/assets/inverted_double_pendulum.xml' ) sys = mjcf.load(path) @@ -176,12 +175,10 @@ def action_size(self): def _get_obs(self, pipeline_sate: base.State) -> jax.Array: """Observe cartpole body position and velocities.""" - return jp.concatenate( - [ - pipeline_sate.q[:1], # cart x pos - jp.sin(pipeline_sate.q[1:]), - jp.cos(pipeline_sate.q[1:]), - jp.clip(pipeline_sate.qd, -10, 10), - # qfrc_constraint is not added - ] - ) + return jp.concatenate([ + pipeline_sate.q[:1], # cart x pos + jp.sin(pipeline_sate.q[1:]), + jp.cos(pipeline_sate.q[1:]), + jp.clip(pipeline_sate.qd, -10, 10), + # qfrc_constraint is not added + ]) diff --git a/brax/envs/swimmer.py b/brax/envs/swimmer.py index 41f04ba4..8e3dba14 100644 --- a/brax/envs/swimmer.py +++ b/brax/envs/swimmer.py @@ -107,13 +107,15 @@ class Swimmer(PipelineEnv): # pyformat: enable - def __init__(self, - forward_reward_weight=1.0, - ctrl_cost_weight=1e-4, - reset_noise_scale=0.1, - exclude_current_positions_from_observation=True, - backend='generalized', - **kwargs): + def __init__( + self, + forward_reward_weight=1.0, + ctrl_cost_weight=1e-4, + reset_noise_scale=0.1, + exclude_current_positions_from_observation=True, + backend='generalized', + **kwargs, + ): path = epath.resource_path('brax') / 'envs/assets/swimmer.xml' sys = mjcf.load(path) @@ -130,7 +132,8 @@ def __init__(self, self._ctrl_cost_weight = ctrl_cost_weight self._reset_noise_scale = reset_noise_scale self._exclude_current_positions_from_observation = ( - exclude_current_positions_from_observation) + exclude_current_positions_from_observation + ) def reset(self, rng: jax.Array) -> State: rng, rng1, rng2 = jax.random.split(rng, 3) @@ -157,7 +160,8 @@ def step(self, state: State, action: jax.Array) -> State: if pipeline_state0 is None: raise AssertionError( - 'Cannot compute rewards with pipeline_state0 as Nonetype.') + 'Cannot compute rewards with pipeline_state0 as Nonetype.' + ) xy_position = pipeline_state.q[:2] diff --git a/brax/envs/walker2d.py b/brax/envs/walker2d.py index 54379e63..a13503e4 100644 --- a/brax/envs/walker2d.py +++ b/brax/envs/walker2d.py @@ -203,7 +203,7 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" pipeline_state0 = state.pipeline_state - assert pipeline_state0 is not None + assert pipeline_state0 is not None pipeline_state = self.pipeline_step(pipeline_state0, action) x_velocity = ( diff --git a/brax/envs/wrappers/dm_env.py b/brax/envs/wrappers/dm_env.py index 2e7a290c..ac9013fb 100644 --- a/brax/envs/wrappers/dm_env.py +++ b/brax/envs/wrappers/dm_env.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Wrappers to convert brax envs to DM Env envs.""" + from typing import Optional from brax.envs.base import PipelineEnv @@ -27,10 +28,9 @@ class DmEnvWrapper(dm_env.Environment): """A wrapper that converts Brax Env to one that follows Dm Env API.""" - def __init__(self, - env: PipelineEnv, - seed: int = 0, - backend: Optional[str] = None): + def __init__( + self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None + ): self._env = env self.seed(seed) self.backend = backend @@ -40,25 +40,32 @@ def __init__(self, self._observation_spec = self._env.observation_spec() else: obs_high = jp.inf * jp.ones(self._env.observation_size, dtype='float32') - self._observation_spec = specs.BoundedArray((self._env.observation_size,), - minimum=-obs_high, - maximum=obs_high, - dtype='float32', - name='observation') + self._observation_spec = specs.BoundedArray( + (self._env.observation_size,), + minimum=-obs_high, + maximum=obs_high, + dtype='float32', + name='observation', + ) if hasattr(self._env, 'action_spec'): self._action_spec = self._env.action_spec() else: action = jax.tree.map(np.array, self._env.sys.actuator.ctrl_range) - self._action_spec = specs.BoundedArray((self._env.action_size,), - minimum=action[:, 0], - maximum=action[:, 1], - dtype='float32', - name='action') - - self._reward_spec = specs.Array(shape=(), dtype=jp.dtype('float32'), name='reward') + self._action_spec = specs.BoundedArray( + (self._env.action_size,), + minimum=action[:, 0], + maximum=action[:, 1], + dtype='float32', + name='action', + ) + + self._reward_spec = specs.Array( + shape=(), dtype=jp.dtype('float32'), name='reward' + ) self._discount_spec = specs.BoundedArray( - shape=(), dtype='float32', minimum=0., maximum=1., name='discount') + shape=(), dtype='float32', minimum=0.0, maximum=1.0, name='discount' + ) if hasattr(self._env, 'discount_spec'): self._discount_spec = self._env.discount_spec() @@ -81,8 +88,9 @@ def reset(self): return dm_env.TimeStep( step_type=dm_env.StepType.FIRST, reward=None, - discount=jp.float32(1.), - observation=obs) + discount=jp.float32(1.0), + observation=obs, + ) def step(self, action): self._state, obs, reward, done, info = self._step(self._state, action) @@ -90,8 +98,9 @@ def step(self, action): return dm_env.TimeStep( step_type=dm_env.StepType.MID if not done else dm_env.StepType.LAST, reward=reward, - discount=jp.float32(1.), - observation=obs) + discount=jp.float32(1.0), + observation=obs, + ) def seed(self, seed: int = 0): self._key = jax.random.PRNGKey(seed) diff --git a/brax/envs/wrappers/dm_env_test.py b/brax/envs/wrappers/dm_env_test.py index 7724adc0..40e1aaa2 100644 --- a/brax/envs/wrappers/dm_env_test.py +++ b/brax/envs/wrappers/dm_env_test.py @@ -27,9 +27,11 @@ def test_action_space(self): base_env = envs.create('pusher') env = dm_env.DmEnvWrapper(base_env) np.testing.assert_array_equal( - env.action_spec().minimum, base_env.sys.actuator.ctrl_range[:, 0]) + env.action_spec().minimum, base_env.sys.actuator.ctrl_range[:, 0] + ) np.testing.assert_array_equal( - env.action_spec().maximum, base_env.sys.actuator.ctrl_range[:, 1]) + env.action_spec().maximum, base_env.sys.actuator.ctrl_range[:, 1] + ) if __name__ == '__main__': diff --git a/brax/envs/wrappers/gym.py b/brax/envs/wrappers/gym.py index 13ba55cb..fee40339 100644 --- a/brax/envs/wrappers/gym.py +++ b/brax/envs/wrappers/gym.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Wrappers to convert brax envs to gym envs.""" + from typing import ClassVar, Optional from brax.envs.base import PipelineEnv @@ -31,14 +32,13 @@ class GymWrapper(gym.Env): # `_reset` as signs of a deprecated gym Env API. _gym_disable_underscore_compat: ClassVar[bool] = True - def __init__(self, - env: PipelineEnv, - seed: int = 0, - backend: Optional[str] = None): + def __init__( + self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None + ): self._env = env self.metadata = { 'render.modes': ['human', 'rgb_array'], - 'video.frames_per_second': 1 / self._env.dt + 'video.frames_per_second': 1 / self._env.dt, } self.seed(seed) self.backend = backend @@ -94,14 +94,13 @@ class VectorGymWrapper(gym.vector.VectorEnv): # `_reset` as signs of a deprecated gym Env API. _gym_disable_underscore_compat: ClassVar[bool] = True - def __init__(self, - env: PipelineEnv, - seed: int = 0, - backend: Optional[str] = None): + def __init__( + self, env: PipelineEnv, seed: int = 0, backend: Optional[str] = None + ): self._env = env self.metadata = { 'render.modes': ['human', 'rgb_array'], - 'video.frames_per_second': 1 / self._env.dt + 'video.frames_per_second': 1 / self._env.dt, } if not hasattr(self._env, 'batch_size'): raise ValueError('underlying env must be batched') diff --git a/brax/envs/wrappers/gym_test.py b/brax/envs/wrappers/gym_test.py index 6e668847..7569ac5b 100644 --- a/brax/envs/wrappers/gym_test.py +++ b/brax/envs/wrappers/gym_test.py @@ -28,9 +28,11 @@ def test_action_space(self): base_env = envs.create('pusher') env = gym.GymWrapper(base_env) np.testing.assert_array_equal( - env.action_space.low, base_env.sys.actuator.ctrl_range[:, 0]) + env.action_space.low, base_env.sys.actuator.ctrl_range[:, 0] + ) np.testing.assert_array_equal( - env.action_space.high, base_env.sys.actuator.ctrl_range[:, 1]) + env.action_space.high, base_env.sys.actuator.ctrl_range[:, 1] + ) def test_vector_action_space(self): @@ -39,10 +41,12 @@ def test_vector_action_space(self): env = gym.VectorGymWrapper(training.VmapWrapper(base_env, batch_size=256)) np.testing.assert_array_equal( env.action_space.low, - np.tile(base_env.sys.actuator.ctrl_range[:, 0], [256, 1])) + np.tile(base_env.sys.actuator.ctrl_range[:, 0], [256, 1]), + ) np.testing.assert_array_equal( env.action_space.high, - np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1])) + np.tile(base_env.sys.actuator.ctrl_range[:, 1], [256, 1]), + ) if __name__ == '__main__': diff --git a/brax/envs/wrappers/torch.py b/brax/envs/wrappers/torch.py index c0238797..83d0cdd6 100644 --- a/brax/envs/wrappers/torch.py +++ b/brax/envs/wrappers/torch.py @@ -16,6 +16,7 @@ This conversion happens directly on-device, without moving values to the CPU. """ + from typing import Optional # NOTE: The following line will emit a warning and raise ImportError if `torch` diff --git a/brax/envs/wrappers/training_test.py b/brax/envs/wrappers/training_test.py index 15727ce4..632ae42d 100644 --- a/brax/envs/wrappers/training_test.py +++ b/brax/envs/wrappers/training_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Tests for training wrappers.""" + import functools from absl.testing import absltest diff --git a/brax/generalized/constraint.py b/brax/generalized/constraint.py index 29782cc2..78f2aa02 100644 --- a/brax/generalized/constraint.py +++ b/brax/generalized/constraint.py @@ -84,6 +84,7 @@ def point_jacobian( Returns: pt: point jacobian """ + # backward scan up tree: build the link mask corresponding to link_idx def mask_fn(mask_child, link): mask = link == link_idx diff --git a/brax/generalized/dynamics.py b/brax/generalized/dynamics.py index 6baf6076..b0b55a34 100644 --- a/brax/generalized/dynamics.py +++ b/brax/generalized/dynamics.py @@ -44,12 +44,10 @@ def transform_com(sys: System, state: State) -> State: cinr = x_i.replace(pos=x_i.pos - root_com).vmap().do(sys.link.inertia) # motion dofs to global frame centered at subtree-CoM - parent_idx = jp.array( - [ - i if t == 'f' else p - for i, (t, p) in enumerate(zip(sys.link_types, sys.link_parents)) - ] - ) + parent_idx = jp.array([ + i if t == 'f' else p + for i, (t, p) in enumerate(zip(sys.link_types, sys.link_parents)) + ]) parent = state.x.concatenate(Transform.zero(shape=(1,))).take(parent_idx) j = parent.vmap().do(sys.link.transform).vmap().do(sys.link.joint) @@ -150,6 +148,7 @@ def inverse(sys: System, state: State) -> jax.Array: Returns: tau: generalized forces resulting from joint positions and velocities """ + # forward scan over tree: accumulate link center of mass acceleration def cdd_fn(cdd_parent, cdofd, qd, dof_idx): if cdd_parent is None: @@ -187,6 +186,7 @@ def cfrc_fn(cfrc_child, cfrc): def _passive(sys: System, state: State) -> jax.Array: """Calculates the system's passive forces given input motion and position.""" + def stiffness_fn(typ, q, dof): if typ in 'fb': return jp.zeros_like(dof.stiffness) diff --git a/brax/generalized/dynamics_test.py b/brax/generalized/dynamics_test.py index 6a5dd5e3..0f707ed3 100644 --- a/brax/generalized/dynamics_test.py +++ b/brax/generalized/dynamics_test.py @@ -27,8 +27,11 @@ class DynamicsTest(parameterized.TestCase): @parameterized.parameters( - 'ant.xml', 'triple_pendulum.xml', ('humanoid.xml',), - ('half_cheetah.xml',), ('swimmer.xml',), + 'ant.xml', + 'triple_pendulum.xml', + ('humanoid.xml',), + ('half_cheetah.xml',), + ('swimmer.xml',), ) def test_transform_com(self, xml_file): """Test dynamics transform com.""" @@ -53,8 +56,11 @@ def test_transform_com(self, xml_file): np.testing.assert_almost_equal(state.cdofd.matrix(), mj_next.cdof_dot, 5) @parameterized.parameters( - 'ant.xml', 'triple_pendulum.xml', ('humanoid.xml',), - ('half_cheetah.xml',), ('swimmer.xml',), + 'ant.xml', + 'triple_pendulum.xml', + ('humanoid.xml',), + ('half_cheetah.xml',), + ('swimmer.xml',), ) def test_forward(self, xml_file): """Test dynamics forward.""" diff --git a/brax/generalized/mass.py b/brax/generalized/mass.py index de72648d..5a7cd535 100644 --- a/brax/generalized/mass.py +++ b/brax/generalized/mass.py @@ -39,6 +39,7 @@ def matrix(sys: System, state: State) -> jax.Array: a symmetric positive matrix (qd_size, qd_size) representing the generalized mass and inertia of the system """ + # backward scan up tree: accumulate composite link inertias def crb_fn(crb_child, crb): if crb_child is not None: diff --git a/brax/io/json.py b/brax/io/json.py index e7dac857..6247e362 100644 --- a/brax/io/json.py +++ b/brax/io/json.py @@ -16,7 +16,7 @@ """Saves a system config and trajectory as json.""" import json -from typing import Optional, List, Text, Tuple +from typing import List, Optional, Text, Tuple from 
brax.base import State, System from etils import epath diff --git a/brax/io/mjcf.py b/brax/io/mjcf.py index dbc7e81e..d69cb749 100644 --- a/brax/io/mjcf.py +++ b/brax/io/mjcf.py @@ -38,8 +38,10 @@ def _transform_do( - parent_pos: np.ndarray, parent_quat: np.ndarray, pos: np.ndarray, - quat: np.ndarray + parent_pos: np.ndarray, + parent_quat: np.ndarray, + pos: np.ndarray, + quat: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: pos = parent_pos + math.rotate_np(pos, parent_quat) rot = math.quat_mul_np(parent_quat, quat) @@ -47,7 +49,8 @@ def _transform_do( def _offset( - elem: ElementTree.Element, parent_pos: np.ndarray, parent_quat: np.ndarray): + elem: ElementTree.Element, parent_pos: np.ndarray, parent_quat: np.ndarray +): """Offsets an element.""" pos = elem.attrib.get('pos', '0 0 0') quat = elem.attrib.get('quat', '1 0 0 0') @@ -264,7 +267,8 @@ def validate_model(mj: mujoco.MjModel) -> None: non_free = np.concatenate([[j != 0] * q_width[j] for j in mj.jnt_type]) if mj.qpos0[non_free].any(): raise NotImplementedError( - 'The `ref` attribute on joint types is not supported.') + 'The `ref` attribute on joint types is not supported.' + ) for _, group in itertools.groupby( zip(mj.jnt_bodyid, mj.jnt_pos), key=lambda x: x[0] @@ -276,9 +280,7 @@ def validate_model(mj: mujoco.MjModel) -> None: # check dofs jnt_range = mj.jnt_range.copy() jnt_range[~(mj.jnt_limited == 1), :] = np.array([-np.inf, np.inf]) - for typ, limit, stiffness in zip( - mj.jnt_type, jnt_range, mj.jnt_stiffness - ): + for typ, limit, stiffness in zip(mj.jnt_type, jnt_range, mj.jnt_stiffness): if typ == 0: if stiffness > 0: raise RuntimeError('brax does not support stiffness for free joints') @@ -414,9 +416,7 @@ def load_model(mj: mujoco.MjModel) -> System: act_kwargs = jax.tree.map(lambda x: x[act_mask], act_kwargs) actuator = Actuator( # pytype: disable=wrong-arg-types - q_id=q_id, - qd_id=qd_id, - **act_kwargs + q_id=q_id, qd_id=qd_id, **act_kwargs ) # create non-pytree params. these do not live on device directly, and they diff --git a/brax/io/mjcf_test.py b/brax/io/mjcf_test.py index a70caa4c..c2e8676e 100644 --- a/brax/io/mjcf_test.py +++ b/brax/io/mjcf_test.py @@ -141,5 +141,6 @@ def test_loads_different_transmission(self): with self.assertRaisesRegex(NotImplementedError, 'transmission types'): mjcf.validate_model(mj) # raises an error + if __name__ == '__main__': absltest.main() diff --git a/brax/io/torch.py b/brax/io/torch.py index ce555664..7e66b63a 100644 --- a/brax/io/torch.py +++ b/brax/io/torch.py @@ -13,6 +13,7 @@ # limitations under the License. """Functions to convert Jax Arrays into PyTorch Tensors and vice-versa.""" + from collections import abc import functools from typing import Any, Dict, Union @@ -28,7 +29,8 @@ except ImportError: warnings.warn( "brax.io.torch requires PyTorch. Please run `pip install torch` to use " - "functions from this module.") + "functions from this module." 
+ ) raise Device = Union[str, torch.device] @@ -57,7 +59,7 @@ def _tensor_to_jax(value: torch.Tensor) -> jax.Array: @torch_to_jax.register(abc.Mapping) def _torch_dict_to_jax( - value: Dict[str, Union[torch.Tensor, Any]] + value: Dict[str, Union[torch.Tensor, Any]], ) -> Dict[str, Union[jax.Array, Any]]: """Converts a dict of PyTorch tensors into a dict of jax.Arrays.""" return type(value)(**{k: torch_to_jax(v) for k, v in value.items()}) # type: ignore @@ -94,8 +96,9 @@ def _jaxarray_to_tensor( @jax_to_torch.register(abc.Mapping) def _jax_dict_to_torch( - value: Dict[str, Union[jax.Array, Any]], - device: Union[Device, None] = None) -> Dict[str, Union[torch.Tensor, Any]]: + value: Dict[str, Union[jax.Array, Any]], device: Union[Device, None] = None +) -> Dict[str, Union[torch.Tensor, Any]]: """Converts a dict of jax.Arrays into a dict of PyTorch tensors.""" return type(value)( - **{k: jax_to_torch(v, device=device) for k, v in value.items()}) # type: ignore + **{k: jax_to_torch(v, device=device) for k, v in value.items()} + ) # type: ignore diff --git a/brax/kinematics.py b/brax/kinematics.py index 6aee7ea7..33e60e6a 100644 --- a/brax/kinematics.py +++ b/brax/kinematics.py @@ -16,7 +16,7 @@ """Functions for forward and inverse kinematics.""" import functools -from typing import Tuple, Any +from typing import Any, Tuple from brax import base from brax import math @@ -160,8 +160,9 @@ def link_to_joint_frame(motion: Motion) -> Tuple[Motion, float]: joint might not be aligned with the rotational components of the joint. """ if motion.ang.shape[0] > 3 or motion.ang.shape[0] == 0: - raise AssertionError('Motion shape must be in (0, 3], ' - f'got {motion.ang.shape[0]}') + raise AssertionError( + f'Motion shape must be in (0, 3], got {motion.ang.shape[0]}' + ) # 1-dof if motion.ang.shape[0] == 1: @@ -372,6 +373,7 @@ def q_fn(typ, j, jd, parent_idx, motion): return jp.array(q).reshape(-1), jp.array(qd).reshape(-1) parent_idx = jp.array(sys.link_parents) - q, qd = scan.link_types(sys, q_fn, 'llld', 'qd', j, jd, parent_idx, - sys.dof.motion) + q, qd = scan.link_types( + sys, q_fn, 'llld', 'qd', j, jd, parent_idx, sys.dof.motion + ) return q, qd diff --git a/brax/kinematics_test.py b/brax/kinematics_test.py index ed927809..cb730725 100644 --- a/brax/kinematics_test.py +++ b/brax/kinematics_test.py @@ -41,20 +41,23 @@ def test_forward(self, xml_file): sys = test_utils.load_fixture(xml_file) for mj_prev, mj_next in test_utils.sample_mujoco_states( - xml_file, random_init=True, vel_to_local=False): + xml_file, random_init=True, vel_to_local=False + ): x, xd = jax.jit(kinematics.forward)(sys, mj_prev.qpos, mj_prev.qvel) np.testing.assert_almost_equal(x.pos, mj_next.xpos[1:], 3) # handle quat rotations +/- 2pi quat_sign = np.allclose( - np.sum(mj_next.xquat[1:]) - np.sum(x.rot), 0, atol=1e-2) + np.sum(mj_next.xquat[1:]) - np.sum(x.rot), 0, atol=1e-2 + ) quat_sign = 1 if quat_sign else -1 x = x.replace(rot=x.rot * quat_sign) np.testing.assert_almost_equal(x.rot, mj_next.xquat[1:], 3) # xd vel/ang were added to linvel/angmom in `sample_mujoco_states` xd_mj = Motion( - vel=mj_next.subtree_linvel[1:], ang=mj_next.subtree_angmom[1:]) + vel=mj_next.subtree_linvel[1:], ang=mj_next.subtree_angmom[1:] + ) if xml_file == 'humanoid.xml': # TODO: get forward to match MJ for stacked/offset joints diff --git a/brax/math.py b/brax/math.py index 97f78ebc..9519caf7 100644 --- a/brax/math.py +++ b/brax/math.py @@ -14,7 +14,7 @@ """Some useful math functions.""" -from typing import Tuple, Optional, Union +from typing import 
Optional, Tuple, Union import jax from jax import custom_jvp @@ -255,8 +255,7 @@ def orthogonals(a: jax.Array) -> Tuple[jax.Array, jax.Array]: def solve_pgs(a: jax.Array, b: jax.Array, num_iters: int) -> jax.Array: - """Projected Gauss-Seidel solver for a MLCP defined by matrix A and vector b. - """ + """Projected Gauss-Seidel solver for a MLCP defined by matrix A and vector b.""" num_rows = b.shape[0] x = jp.zeros((num_rows,)) diff --git a/brax/positional/collisions.py b/brax/positional/collisions.py index 862b74ff..b24c1378 100644 --- a/brax/positional/collisions.py +++ b/brax/positional/collisions.py @@ -13,12 +13,13 @@ # limitations under the License. """Functions to resolve collisions.""" + # pylint:disable=g-multiple-import from typing import Optional, Tuple from brax import com from brax import math -from brax.base import Contact, Motion, System, Transform, Force +from brax.base import Contact, Force, Motion, System, Transform from brax.positional.base import State import jax from jax import numpy as jp diff --git a/brax/positional/integrator.py b/brax/positional/integrator.py index c3b4ad23..8a34a108 100644 --- a/brax/positional/integrator.py +++ b/brax/positional/integrator.py @@ -13,6 +13,7 @@ # limitations under the License. """Functions for integrating maximal coordinate dynamics.""" + # pylint:disable=g-multiple-import from typing import Tuple diff --git a/brax/positional/joints.py b/brax/positional/joints.py index 15e78d61..6579ea7c 100644 --- a/brax/positional/joints.py +++ b/brax/positional/joints.py @@ -13,6 +13,7 @@ # limitations under the License. """Joint definition and apply functions.""" + # pylint:disable=g-multiple-import from typing import Tuple @@ -106,7 +107,8 @@ def position_update(sys: System, state: State) -> Transform: mass_inv = 1 / (sys.link.inertia.mass ** (1 - sys.spring_mass_scale)) mass_inv_p = mass_inv[p_idx] * (p_idx > -1) dp_p_pos, dp_c_pos = jax.vmap(_translation_update)( - a_p, xi_p, i_inv_p, mass_inv_p, a_c, state.x_i, i_inv, mass_inv, -d_w.pos) + a_p, xi_p, i_inv_p, mass_inv_p, a_c, state.x_i, i_inv, mass_inv, -d_w.pos + ) dp_p_ang, dp_c_ang = jax.vmap(_rotation_update)( xi_p, i_inv_p, state.x_i, i_inv, d_w.rot ) @@ -175,7 +177,10 @@ def _sphericalize(sys, j): def pad_free(_): # create dummy data for free links inf = jp.array([jp.inf, jp.inf, jp.inf]) - return (-inf, inf), Motion(ang=jp.eye(3), vel=jp.eye(3)), + return ( + (-inf, inf), + Motion(ang=jp.eye(3), vel=jp.eye(3)), + ) def pad_x_dof(dof, x): if dof.limit: diff --git a/brax/positional/pipeline.py b/brax/positional/pipeline.py index 4969d5d2..6b883107 100644 --- a/brax/positional/pipeline.py +++ b/brax/positional/pipeline.py @@ -13,6 +13,7 @@ # limitations under the License. """Physics pipeline for fully articulated dynamics and collisiion.""" + # pylint:disable=g-multiple-import from typing import Optional from brax import actuator diff --git a/brax/spring/base.py b/brax/spring/base.py index f8d8d306..f8ee082a 100644 --- a/brax/spring/base.py +++ b/brax/spring/base.py @@ -36,6 +36,7 @@ class State(base.State): i_inv: link inverse inertia mass: link mass """ + x_i: Transform xd_i: Motion j: Transform diff --git a/brax/spring/collisions.py b/brax/spring/collisions.py index 525481f1..269a4cb7 100644 --- a/brax/spring/collisions.py +++ b/brax/spring/collisions.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Function to resolve collisions.""" + # pylint:disable=g-multiple-import from brax import contact from brax import math diff --git a/brax/spring/integrator.py b/brax/spring/integrator.py index c9e6ceb0..8c382e89 100644 --- a/brax/spring/integrator.py +++ b/brax/spring/integrator.py @@ -13,6 +13,7 @@ # limitations under the License. """Functions for integrating maximal coordinate dynamics.""" + # pylint:disable=g-multiple-import from typing import Tuple diff --git a/brax/spring/joints.py b/brax/spring/joints.py index dd4706f1..b453217c 100644 --- a/brax/spring/joints.py +++ b/brax/spring/joints.py @@ -13,6 +13,7 @@ # limitations under the License. """Joint definition and apply functions.""" + # pylint:disable=g-multiple-import from brax import kinematics from brax import math @@ -54,9 +55,7 @@ def _one_dof( is_rotational = dof.motion.ang.any() vel = ( vel - - jp.dot(joint_frame.vel[0], vel) - * joint_frame.vel[0] - * is_translational + - jp.dot(joint_frame.vel[0], vel) * joint_frame.vel[0] * is_translational ) # add in force @@ -66,9 +65,7 @@ def _one_dof( # if prismatic, don't damp along free axis vel += ( damp - - jp.dot(joint_frame.vel[0], damp) - * joint_frame.vel[0] - * is_translational + - jp.dot(joint_frame.vel[0], damp) * joint_frame.vel[0] * is_translational ) axis_c_x = math.rotate(joint_frame.ang[0], j.rot) diff --git a/brax/test_utils.py b/brax/test_utils.py index cc34d4cc..12be27e2 100644 --- a/brax/test_utils.py +++ b/brax/test_utils.py @@ -44,16 +44,23 @@ def _normalize_q(model: mujoco.MjModel, q: np.ndarray): for typ in model.jnt_type: q_dim = 7 if typ == 0 else 1 if typ == 0: - q[q_idx + 3:q_idx + 7] = ( - q[q_idx + 3:q_idx + 7] / np.linalg.norm(q[q_idx + 3:q_idx + 7])) + q[q_idx + 3 : q_idx + 7] = q[q_idx + 3 : q_idx + 7] / np.linalg.norm( + q[q_idx + 3 : q_idx + 7] + ) q_idx += q_dim return q def sample_mujoco_states( - path: str, count: int = 500, modulo: int = 20, force_pgs: bool = False, - random_init: bool = False, random_q_scale: float = 1.0, - random_qd_scale: float = 0.1, vel_to_local: bool = True, seed: int = 42 + path: str, + count: int = 500, + modulo: int = 20, + force_pgs: bool = False, + random_init: bool = False, + random_q_scale: float = 1.0, + random_qd_scale: float = 0.1, + vel_to_local: bool = True, + seed: int = 42, ) -> Iterable[Tuple[mujoco.MjData, mujoco.MjData]]: """Samples count / modulo states from mujoco for comparison.""" np.random.seed(seed) @@ -76,7 +83,8 @@ def sample_mujoco_states( for i in range(model.nbody): vel = np.zeros((6,)) mujoco.mj_objectVelocity( - model, data, mujoco.mjtObj.mjOBJ_XBODY.value, i, vel, vel_to_local) + model, data, mujoco.mjtObj.mjOBJ_XBODY.value, i, vel, vel_to_local + ) data.subtree_angmom[i] = vel[:3] data.subtree_linvel[i] = vel[3:] yield before, data diff --git a/brax/training/acting.py b/brax/training/acting.py index 14a7c595..ec6f81d3 100644 --- a/brax/training/acting.py +++ b/brax/training/acting.py @@ -36,7 +36,7 @@ def actor_step( env_state: State, policy: Policy, key: PRNGKey, - extra_fields: Sequence[str] = () + extra_fields: Sequence[str] = (), ) -> Tuple[State, Transition]: """Collect data.""" actions, policy_extras = policy(env_state.obs, key) @@ -48,10 +48,8 @@ def actor_step( reward=nstate.reward, discount=1 - nstate.done, next_observation=nstate.obs, - extras={ - 'policy_extras': policy_extras, - 'state_extras': state_extras - }) + extras={'policy_extras': policy_extras, 'state_extras': state_extras}, + ) def generate_unroll( @@ -60,7 +58,7 @@ def generate_unroll( policy: Policy, key: PRNGKey, 
unroll_length: int, - extra_fields: Sequence[str] = () + extra_fields: Sequence[str] = (), ) -> Tuple[State, Transition]: """Collect trajectories of given unroll_length.""" @@ -69,11 +67,13 @@ def f(carry, unused_t): state, current_key = carry current_key, next_key = jax.random.split(current_key) nstate, transition = actor_step( - env, state, policy, current_key, extra_fields=extra_fields) + env, state, policy, current_key, extra_fields=extra_fields + ) return (nstate, next_key), transition (final_state, _), data = jax.lax.scan( - f, (env_state, key), (), length=unroll_length) + f, (env_state, key), (), length=unroll_length + ) return final_state, data @@ -81,10 +81,15 @@ def f(carry, unused_t): class Evaluator: """Class to run evaluations.""" - def __init__(self, eval_env: envs.Env, - eval_policy_fn: Callable[[PolicyParams], - Policy], num_eval_envs: int, - episode_length: int, action_repeat: int, key: PRNGKey): + def __init__( + self, + eval_env: envs.Env, + eval_policy_fn: Callable[[PolicyParams], Policy], + num_eval_envs: int, + episode_length: int, + action_repeat: int, + key: PRNGKey, + ): """Init. Args: @@ -96,12 +101,13 @@ def __init__(self, eval_env: envs.Env, key: RNG key. """ self._key = key - self._eval_walltime = 0. + self._eval_walltime = 0.0 eval_env = envs.training.EvalWrapper(eval_env) - def generate_eval_unroll(policy_params: PolicyParams, - key: PRNGKey) -> State: + def generate_eval_unroll( + policy_params: PolicyParams, key: PRNGKey + ) -> State: reset_keys = jax.random.split(key, num_eval_envs) eval_first_state = eval_env.reset(reset_keys) return generate_unroll( @@ -109,15 +115,18 @@ def generate_eval_unroll(policy_params: PolicyParams, eval_first_state, eval_policy_fn(policy_params), key, - unroll_length=episode_length // action_repeat)[0] + unroll_length=episode_length // action_repeat, + )[0] self._generate_eval_unroll = jax.jit(generate_eval_unroll) self._steps_per_unroll = episode_length * num_eval_envs - def run_evaluation(self, - policy_params: PolicyParams, - training_metrics: Metrics, - aggregate_episodes: bool = True) -> Metrics: + def run_evaluation( + self, + policy_params: PolicyParams, + training_metrics: Metrics, + aggregate_episodes: bool = True, + ) -> Metrics: """Run one epoch of evaluation.""" self._key, unroll_key = jax.random.split(self._key) @@ -129,14 +138,12 @@ def run_evaluation(self, metrics = {} for fn in [np.mean, np.std]: suffix = '_std' if fn == np.std else '' - metrics.update( - { - f'eval/episode_{name}{suffix}': ( - fn(value) if aggregate_episodes else value - ) - for name, value in eval_metrics.episode_metrics.items() - } - ) + metrics.update({ + f'eval/episode_{name}{suffix}': ( + fn(value) if aggregate_episodes else value + ) + for name, value in eval_metrics.episode_metrics.items() + }) metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps) metrics['eval/epoch_eval_time'] = epoch_eval_time metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time @@ -144,7 +151,7 @@ def run_evaluation(self, metrics = { 'eval/walltime': self._eval_walltime, **training_metrics, - **metrics + **metrics, } return metrics # pytype: disable=bad-return-type # jax-ndarray diff --git a/brax/training/agents/apg/networks.py b/brax/training/agents/apg/networks.py index 56254e63..3707f484 100644 --- a/brax/training/agents/apg/networks.py +++ b/brax/training/agents/apg/networks.py @@ -14,7 +14,7 @@ """APG networks.""" -from typing import Sequence,Tuple +from typing import Sequence, Tuple from brax.training import distribution from 
brax.training import networks @@ -33,16 +33,22 @@ class APGNetworks: def make_inference_fn(apg_networks: APGNetworks): """Creates params and inference function for the APG agent.""" - def make_policy(params: types.PolicyParams, - deterministic: bool = False) -> types.Policy: + def make_policy( + params: types.PolicyParams, deterministic: bool = False + ) -> types.Policy: - def policy(observations: types.Observation, - key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + def policy( + observations: types.Observation, key_sample: PRNGKey + ) -> Tuple[types.Action, types.Extra]: logits = apg_networks.policy_network.apply(*params, observations) if deterministic: return apg_networks.parametric_action_distribution.mode(logits), {} - return apg_networks.parametric_action_distribution.sample( - logits, key_sample), {} + return ( + apg_networks.parametric_action_distribution.sample( + logits, key_sample + ), + {}, + ) return policy @@ -52,21 +58,25 @@ def policy(observations: types.Observation, def make_apg_networks( observation_size: int, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (32,) * 4, activation: networks.ActivationFn = linen.elu, - layer_norm: bool = True) -> APGNetworks: + layer_norm: bool = True, +) -> APGNetworks: """Make APG networks.""" parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size, var_scale=0.1) + event_size=action_size, var_scale=0.1 + ) policy_network = networks.make_policy_network( parametric_action_distribution.param_size, observation_size, preprocess_observations_fn=preprocess_observations_fn, - hidden_layer_sizes=hidden_layer_sizes, activation=activation, + hidden_layer_sizes=hidden_layer_sizes, + activation=activation, kernel_init=linen.initializers.orthogonal(0.01), - layer_norm=layer_norm) + layer_norm=layer_norm, + ) return APGNetworks( policy_network=policy_network, - parametric_action_distribution=parametric_action_distribution) + parametric_action_distribution=parametric_action_distribution, + ) diff --git a/brax/training/agents/apg/train.py b/brax/training/agents/apg/train.py index 9c4336fa..27ff0a3b 100644 --- a/brax/training/agents/apg/train.py +++ b/brax/training/agents/apg/train.py @@ -45,6 +45,7 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" + optimizer_state: optax.OptState normalizer_params: running_statistics.RunningStatisticsState policy_params: Params @@ -94,8 +95,13 @@ def train( local_devices_to_use = min(local_devices_to_use, max_devices_per_host) logging.info( 'Device count: %d, process count: %d (id %d), local device count: %d, ' - 'devices to be used count: %d', jax.device_count(), process_count, - process_id, local_device_count, local_devices_to_use) + 'devices to be used count: %d', + jax.device_count(), + process_count, + process_id, + local_device_count, + local_devices_to_use, + ) device_count = local_devices_to_use * process_count num_updates = policy_updates @@ -139,27 +145,31 @@ def train( if normalize_observations: normalize = running_statistics.normalize apg_network = network_factory( - obs_size, - env.action_size, - preprocess_observations_fn=normalize) + obs_size, env.action_size, preprocess_observations_fn=normalize + ) make_policy = apg_networks.make_inference_fn(apg_network) if use_schedule: learning_rate = 
optax.exponential_decay( - init_value=learning_rate, - transition_steps=1, - decay_rate=schedule_decay + init_value=learning_rate, transition_steps=1, decay_rate=schedule_decay ) optimizer = optax.chain( optax.clip(1.0), - optax.adam(learning_rate=learning_rate, b1=adam_b[0], b2=adam_b[1]) + optax.adam(learning_rate=learning_rate, b1=adam_b[0], b2=adam_b[1]), ) def scramble_times(state, key): state.info['steps'] = jnp.round( - jax.random.uniform(key, (local_devices_to_use, num_envs,), - maxval=episode_length)) + jax.random.uniform( + key, + ( + local_devices_to_use, + num_envs, + ), + maxval=episode_length, + ) + ) return state def env_step( @@ -283,7 +293,7 @@ def training_epoch_with_timing( metrics = { 'training/sps': sps, 'training/walltime': training_walltime, - **{f'training/{name}': value for name, value in metrics.items()} + **{f'training/{name}': value for name, value in metrics.items()}, } return training_state, env_state, metrics, key # pytype: disable=bad-return-type # py311-upgrade @@ -297,10 +307,12 @@ def training_epoch_with_timing( optimizer_state=optimizer.init(policy_params), policy_params=policy_params, normalizer_params=running_statistics.init_state( - specs.Array((env.observation_size,), jnp.dtype(dtype)))) + specs.Array((env.observation_size,), jnp.dtype(dtype)) + ), + ) training_state = jax.device_put_replicated( - training_state, - jax.local_devices()[:local_devices_to_use]) + training_state, jax.local_devices()[:local_devices_to_use] + ) if not eval_env: eval_env = environment @@ -322,7 +334,8 @@ def training_epoch_with_timing( num_eval_envs=num_eval_envs, episode_length=episode_length, action_repeat=action_repeat, - key=eval_key) + key=eval_key, + ) # Run initial eval metrics = {} @@ -376,6 +389,7 @@ def training_epoch_with_timing( # devices. pmap.assert_is_replicated(training_state) params = _unpmap( - (training_state.normalizer_params, training_state.policy_params)) + (training_state.normalizer_params, training_state.policy_params) + ) pmap.synchronize_hosts() return (make_policy, params, metrics) diff --git a/brax/training/agents/apg/train_test.py b/brax/training/agents/apg/train_test.py index 994099bd..26e55e5e 100644 --- a/brax/training/agents/apg/train_test.py +++ b/brax/training/agents/apg/train_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Analytic policy gradient tests.""" + import pickle from absl.testing import absltest @@ -52,13 +53,14 @@ def testNetworkEncoding(self, normalize_observations): num_envs=16, learning_rate=3e-3, normalize_observations=normalize_observations, - num_evals=200 + num_evals=200, ) normalize_fn = lambda x, y: x if normalize_observations: normalize_fn = running_statistics.normalize - apg_network = apg_networks.make_apg_networks(env.observation_size, - env.action_size, normalize_fn) + apg_network = apg_networks.make_apg_networks( + env.observation_size, env.action_size, normalize_fn + ) inference = apg_networks.make_inference_fn(apg_network) byte_encoding = pickle.dumps(params) decoded_params = pickle.loads(byte_encoding) @@ -66,7 +68,8 @@ def testNetworkEncoding(self, normalize_observations): # Compute one action. 
state = env.reset(jax.random.PRNGKey(0)) original_action = original_inference(decoded_params)( - state.obs, jax.random.PRNGKey(0))[0] + state.obs, jax.random.PRNGKey(0) + )[0] action = inference(decoded_params)(state.obs, jax.random.PRNGKey(0))[0] self.assertSequenceEqual(original_action, action) env.step(state, action) @@ -97,5 +100,6 @@ def get_offset(rng): randomization_fn=rand_fn, ) + if __name__ == '__main__': absltest.main() diff --git a/brax/training/agents/ars/networks.py b/brax/training/agents/ars/networks.py index 09f18186..7bb17821 100644 --- a/brax/training/agents/ars/networks.py +++ b/brax/training/agents/ars/networks.py @@ -27,8 +27,7 @@ def make_policy_network( observation_size: int, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, ) -> ARSNetwork: """Creates a policy network.""" @@ -37,7 +36,8 @@ def apply(processor_params, policy_params, obs): return jnp.matmul(obs, policy_params) return ARSNetwork( - init=lambda _: jnp.zeros((observation_size, action_size)), apply=apply) + init=lambda _: jnp.zeros((observation_size, action_size)), apply=apply + ) def make_inference_fn(policy_network: ARSNetwork): @@ -45,8 +45,9 @@ def make_inference_fn(policy_network: ARSNetwork): def make_policy(params: types.PolicyParams) -> types.Policy: - def policy(observations: types.Observation, - unused_key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + def policy( + observations: types.Observation, unused_key_sample: PRNGKey + ) -> Tuple[types.Action, types.Extra]: return policy_network.apply(*params, observations), {} return policy diff --git a/brax/training/agents/ars/train.py b/brax/training/agents/ars/train.py index 42d851fe..9352fe7f 100644 --- a/brax/training/agents/ars/train.py +++ b/brax/training/agents/ars/train.py @@ -44,6 +44,7 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" + normalizer_params: running_statistics.RunningStatisticsState policy_params: Params num_env_steps: int @@ -81,15 +82,19 @@ def train( process_count = jax.process_count() if process_count > 1: - raise ValueError('ES is not compatible with multiple hosts, ' - 'please use a single host device.') + raise ValueError( + 'ES is not compatible with multiple hosts, ' + 'please use a single host device.' 
+ ) local_device_count = jax.local_device_count() local_devices_to_use = local_device_count if max_devices_per_host: local_devices_to_use = min(local_devices_to_use, max_devices_per_host) - logging.info('Local device count: %d, ' - 'devices to be used count: %d', local_device_count, - local_devices_to_use) + logging.info( + 'Local device count: %d, devices to be used count: %d', + local_device_count, + local_devices_to_use, + ) num_env_steps_between_evals = num_timesteps // num_evals next_eval_step = num_timesteps - (num_evals - 1) * num_env_steps_between_evals @@ -128,33 +133,56 @@ def train( ars_network = network_factory( observation_size=obs_size, action_size=env.action_size, - preprocess_observations_fn=normalize_fn) + preprocess_observations_fn=normalize_fn, + ) make_policy = ars_networks.make_inference_fn(ars_network) vmapped_policy = jax.vmap(ars_network.apply, in_axes=(None, 0, 0)) def run_step(carry, unused_target_t): - (env_state, policy_params, cumulative_reward, active_episode, - normalizer_params) = carry + ( + env_state, + policy_params, + cumulative_reward, + active_episode, + normalizer_params, + ) = carry obs = env_state.obs actions = vmapped_policy(normalizer_params, policy_params, obs) nstate = env.step(env_state, actions) - cumulative_reward = cumulative_reward + (nstate.reward - - reward_shift) * active_episode + cumulative_reward = ( + cumulative_reward + (nstate.reward - reward_shift) * active_episode + ) new_active_episode = active_episode * (1 - nstate.done) - return (nstate, policy_params, cumulative_reward, new_active_episode, - normalizer_params), (env_state.obs, active_episode) - - def run_episode(normalizer_params: running_statistics.NestedMeanStd, - params: Params, key: PRNGKey): + return ( + nstate, + policy_params, + cumulative_reward, + new_active_episode, + normalizer_params, + ), (env_state.obs, active_episode) + + def run_episode( + normalizer_params: running_statistics.NestedMeanStd, + params: Params, + key: PRNGKey, + ): reset_keys = jax.random.split(key, num_envs // local_devices_to_use) first_env_states = env.reset(reset_keys) cumulative_reward = first_env_states.reward active_episode = jnp.ones_like(cumulative_reward) (_, _, cumulative_reward, _, _), (obs, obs_weights) = jax.lax.scan( - run_step, (first_env_states, params, cumulative_reward, active_episode, - normalizer_params), (), - length=episode_length // action_repeat) + run_step, + ( + first_env_states, + params, + cumulative_reward, + active_episode, + normalizer_params, + ), + (), + length=episode_length // action_repeat, + ) return cumulative_reward, obs, obs_weights def add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]: @@ -162,20 +190,24 @@ def add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]: treedef = jax.tree_util.tree_structure(params) all_keys = jax.random.split(key, num=num_vars) noise = jax.tree_util.tree_map( - lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), params, - jax.tree_util.tree_unflatten(treedef, all_keys)) + lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), + params, + jax.tree_util.tree_unflatten(treedef, all_keys), + ) params_with_noise = jax.tree_util.tree_map( lambda g, n: g + n * exploration_noise_std, params, noise ) params_with_anti_noise = jax.tree_util.tree_map( - lambda g, n: g - n * exploration_noise_std, params, noise) + lambda g, n: g - n * exploration_noise_std, params, noise + ) return params_with_noise, params_with_anti_noise, noise prun_episode = jax.pmap(run_episode, 
in_axes=(None, 0, 0)) @jax.jit - def training_epoch(training_state: TrainingState, - key: PRNGKey) -> Tuple[TrainingState, Metrics]: + def training_epoch( + training_state: TrainingState, key: PRNGKey + ) -> Tuple[TrainingState, Metrics]: params = jax.tree_util.tree_map( lambda x: jnp.repeat( jnp.expand_dims(x, axis=0), number_of_directions, axis=0 @@ -185,7 +217,8 @@ def training_epoch(training_state: TrainingState, key, key_noise, key_es_eval = jax.random.split(key, 3) # generate perturbations params_with_noise, params_with_anti_noise, noise = add_noise( - params, key_noise) + params, key_noise + ) pparams = jax.tree_util.tree_map( lambda a, b: jnp.concatenate([a, b], axis=0), @@ -195,17 +228,20 @@ def training_epoch(training_state: TrainingState, pparams = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (local_devices_to_use, -1) + x.shape[1:]), - pparams) + pparams, + ) key_es_eval = jax.random.split(key_es_eval, local_devices_to_use) eval_scores, obs, obs_weights = prun_episode( - training_state.normalizer_params, pparams, key_es_eval) + training_state.normalizer_params, pparams, key_es_eval + ) obs = jnp.reshape(obs, (-1,) + obs.shape[2:]) obs_weights = jnp.reshape(obs_weights, (-1,) + obs_weights.shape[2:]) normalizer_params = running_statistics.update( - training_state.normalizer_params, obs, weights=obs_weights) + training_state.normalizer_params, obs, weights=obs_weights + ) eval_scores = jnp.reshape(eval_scores, [-1]) @@ -213,8 +249,9 @@ def training_epoch(training_state: TrainingState, reward_max = jnp.maximum(reward_plus, reward_minus) reward_rank = jnp.argsort(jnp.argsort(-reward_max)) reward_weight = jnp.where(reward_rank < top_directions, 1, 0) - reward_weight_double = jnp.concatenate([reward_weight, reward_weight], - axis=0) + reward_weight_double = jnp.concatenate( + [reward_weight, reward_weight], axis=0 + ) reward_std = jnp.std(eval_scores, where=reward_weight_double) reward_std += (reward_std == 0.0) * 1e-6 @@ -230,10 +267,14 @@ def training_epoch(training_state: TrainingState, policy_params = jax.tree_util.tree_map( lambda x, y: x + step_size * y / (top_directions * reward_std), - training_state.policy_params, noise) + training_state.policy_params, + noise, + ) - num_env_steps = training_state.num_env_steps + jnp.sum( - obs_weights, dtype=jnp.int32) * action_repeat + num_env_steps = ( + training_state.num_env_steps + + jnp.sum(obs_weights, dtype=jnp.int32) * action_repeat + ) metrics = { 'params_norm': optax.global_norm(policy_params), @@ -241,16 +282,21 @@ def training_epoch(training_state: TrainingState, 'eval_scores_std': jnp.std(eval_scores), 'weights': jnp.mean(reward_weight), } - return (TrainingState( # type: ignore # jnp-type - normalizer_params=normalizer_params, - policy_params=policy_params, - num_env_steps=num_env_steps), metrics) + return ( + TrainingState( # type: ignore # jnp-type + normalizer_params=normalizer_params, + policy_params=policy_params, + num_env_steps=num_env_steps, + ), + metrics, + ) - training_walltime = 0. + training_walltime = 0.0 # Note that this is NOT a pure jittable method. 
- def training_epoch_with_timing(training_state: TrainingState, - key: PRNGKey) -> Tuple[TrainingState, Metrics]: + def training_epoch_with_timing( + training_state: TrainingState, key: PRNGKey + ) -> Tuple[TrainingState, Metrics]: nonlocal training_walltime t = time.time() (training_state, metrics) = training_epoch(training_state, key) @@ -263,17 +309,19 @@ def training_epoch_with_timing(training_state: TrainingState, metrics = { 'training/sps': sps, 'training/walltime': training_walltime, - **{f'training/{name}': value for name, value in metrics.items()} + **{f'training/{name}': value for name, value in metrics.items()}, } return training_state, metrics # pytype: disable=bad-return-type # py311-upgrade normalizer_params = running_statistics.init_state( - specs.Array((obs_size,), jnp.dtype('float32'))) + specs.Array((obs_size,), jnp.dtype('float32')) + ) policy_params = ars_network.init(network_key) training_state = TrainingState( normalizer_params=normalizer_params, policy_params=policy_params, - num_env_steps=0) + num_env_steps=0, + ) if not eval_env: eval_env = environment @@ -296,19 +344,22 @@ def training_epoch_with_timing(training_state: TrainingState, num_eval_envs=num_eval_envs, episode_length=episode_length, action_repeat=action_repeat, - key=eval_key) + key=eval_key, + ) while training_state.num_env_steps < num_timesteps: # optimization key, epoch_key = jax.random.split(key) training_state, training_metrics = training_epoch_with_timing( - training_state, epoch_key) + training_state, epoch_key + ) if training_state.num_env_steps >= next_eval_step: # Run evals. metrics = evaluator.run_evaluation( (training_state.normalizer_params, training_state.policy_params), - training_metrics) + training_metrics, + ) logging.info(metrics) progress_fn(int(training_state.num_env_steps), metrics) next_eval_step += num_env_steps_between_evals diff --git a/brax/training/agents/ars/train_test.py b/brax/training/agents/ars/train_test.py index 09254172..96887070 100644 --- a/brax/training/agents/ars/train_test.py +++ b/brax/training/agents/ars/train_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Augmented Random Search training tests.""" + import pickle from absl.testing import absltest @@ -34,13 +35,14 @@ def testModelEncoding(self, normalize_observations): env, num_timesteps=128, episode_length=128, - normalize_observations=normalize_observations) + normalize_observations=normalize_observations, + ) normalize_fn = lambda x, y: x if normalize_observations: normalize_fn = running_statistics.normalize - ars_network = ars_networks.make_policy_network(env.observation_size, - env.action_size, - normalize_fn) + ars_network = ars_networks.make_policy_network( + env.observation_size, env.action_size, normalize_fn + ) inference = ars_networks.make_inference_fn(ars_network) byte_encoding = pickle.dumps(params) decoded_params = pickle.loads(byte_encoding) diff --git a/brax/training/agents/es/networks.py b/brax/training/agents/es/networks.py index 3035f98f..f8827744 100644 --- a/brax/training/agents/es/networks.py +++ b/brax/training/agents/es/networks.py @@ -33,16 +33,20 @@ class ESNetworks: def make_inference_fn(es_networks: ESNetworks): """Creates params and inference function for the ES agent.""" - def make_policy(params: types.PolicyParams, - deterministic: bool = False) -> types.Policy: + def make_policy( + params: types.PolicyParams, deterministic: bool = False + ) -> types.Policy: - def policy(observations: types.Observation, - key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + def policy( + observations: types.Observation, key_sample: PRNGKey + ) -> Tuple[types.Action, types.Extra]: logits = es_networks.policy_network.apply(*params, observations) if deterministic: return es_networks.parametric_action_distribution.mode(logits), {} - return es_networks.parametric_action_distribution.sample( - logits, key_sample), {} + return ( + es_networks.parametric_action_distribution.sample(logits, key_sample), + {}, + ) return policy @@ -52,19 +56,22 @@ def policy(observations: types.Observation, def make_es_networks( observation_size: int, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (32,) * 4, - activation: networks.ActivationFn = linen.relu) -> ESNetworks: + activation: networks.ActivationFn = linen.relu, +) -> ESNetworks: """Make ES networks.""" parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size) + event_size=action_size + ) policy_network = networks.make_policy_network( parametric_action_distribution.param_size, observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=hidden_layer_sizes, - activation=activation) + activation=activation, + ) return ESNetworks( policy_network=policy_network, - parametric_action_distribution=parametric_action_distribution) + parametric_action_distribution=parametric_action_distribution, + ) diff --git a/brax/training/agents/es/train.py b/brax/training/agents/es/train.py index b738842e..c9d91523 100644 --- a/brax/training/agents/es/train.py +++ b/brax/training/agents/es/train.py @@ -45,6 +45,7 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" + normalizer_params: running_statistics.RunningStatisticsState optimizer_state: optax.OptState policy_params: Params @@ -54,8 +55,8 @@ class TrainingState: # Centered rank from: https://arxiv.org/pdf/1703.03864.pdf def centered_rank(x: jnp.ndarray) -> jnp.ndarray: x = 
jnp.argsort(jnp.argsort(x)) - x /= (len(x) - 1) - return x - .5 + x /= len(x) - 1 + return x - 0.5 # Shaping from @@ -105,21 +106,26 @@ def train( process_count = jax.process_count() if process_count > 1: - raise ValueError('ES is not compatible with multiple hosts, ' - 'please use a single host device.') + raise ValueError( + 'ES is not compatible with multiple hosts, ' + 'please use a single host device.' + ) local_device_count = jax.local_device_count() local_devices_to_use = local_device_count if max_devices_per_host: local_devices_to_use = min(local_devices_to_use, max_devices_per_host) - logging.info('Local device count: %d, ' - 'devices to be used count: %d', local_device_count, - local_devices_to_use) + logging.info( + 'Local device count: %d, devices to be used count: %d', + local_device_count, + local_devices_to_use, + ) num_evals_after_init = max(num_evals - 1, 1) num_env_steps_between_evals = num_timesteps // num_evals_after_init - next_eval_step = num_timesteps - (num_evals_after_init - - 1) * num_env_steps_between_evals + next_eval_step = ( + num_timesteps - (num_evals_after_init - 1) * num_env_steps_between_evals + ) key = jax.random.PRNGKey(seed) key, network_key, eval_key, rng_key = jax.random.split(key, 4) @@ -155,39 +161,66 @@ def train( es_network = network_factory( observation_size=obs_size, action_size=env.action_size, - preprocess_observations_fn=normalize_fn) + preprocess_observations_fn=normalize_fn, + ) make_policy = es_networks.make_inference_fn(es_network) optimizer = optax.adam(learning_rate=learning_rate) vmapped_policy = jax.vmap( - es_network.policy_network.apply, in_axes=(None, 0, 0)) + es_network.policy_network.apply, in_axes=(None, 0, 0) + ) def run_step(carry, unused_target_t): - (env_state, policy_params, key, cumulative_reward, active_episode, - normalizer_params) = carry + ( + env_state, + policy_params, + key, + cumulative_reward, + active_episode, + normalizer_params, + ) = carry key, key_sample = jax.random.split(key) obs = env_state.obs logits = vmapped_policy(normalizer_params, policy_params, obs) actions = es_network.parametric_action_distribution.sample( - logits, key_sample) + logits, key_sample + ) nstate = env.step(env_state, actions) cumulative_reward = cumulative_reward + nstate.reward * active_episode new_active_episode = active_episode * (1 - nstate.done) - return (nstate, policy_params, key, cumulative_reward, new_active_episode, - normalizer_params), (env_state.obs, active_episode) - - def run_episode(normalizer_params: running_statistics.NestedMeanStd, - params: Params, key: PRNGKey): + return ( + nstate, + policy_params, + key, + cumulative_reward, + new_active_episode, + normalizer_params, + ), (env_state.obs, active_episode) + + def run_episode( + normalizer_params: running_statistics.NestedMeanStd, + params: Params, + key: PRNGKey, + ): key_scan, key_reset = jax.random.split(key) reset_keys = jax.random.split(key_reset, num_envs // local_devices_to_use) first_env_states = env.reset(reset_keys) cumulative_reward = first_env_states.reward active_episode = jnp.ones_like(cumulative_reward) (_, _, key, cumulative_reward, _, _), (obs, obs_weights) = jax.lax.scan( - run_step, (first_env_states, params, key_scan, cumulative_reward, - active_episode, normalizer_params), (), - length=episode_length // action_repeat) + run_step, + ( + first_env_states, + params, + key_scan, + cumulative_reward, + active_episode, + normalizer_params, + ), + (), + length=episode_length // action_repeat, + ) return cumulative_reward, obs, obs_weights def 
add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]: @@ -195,12 +228,16 @@ def add_noise(params: Params, key: PRNGKey) -> Tuple[Params, Params, Params]: treedef = jax.tree_util.tree_structure(params) all_keys = jax.random.split(key, num=num_vars) noise = jax.tree_util.tree_map( - lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), params, - jax.tree_util.tree_unflatten(treedef, all_keys)) - params_with_noise = jax.tree_util.tree_map(lambda g, n: g + n * perturbation_std, - params, noise) - params_with_anti_noise = jax.tree_util.tree_map(lambda g, n: g - n * perturbation_std, - params, noise) + lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), + params, + jax.tree_util.tree_unflatten(treedef, all_keys), + ) + params_with_noise = jax.tree_util.tree_map( + lambda g, n: g + n * perturbation_std, params, noise + ) + params_with_anti_noise = jax.tree_util.tree_map( + lambda g, n: g - n * perturbation_std, params, noise + ) return params_with_noise, params_with_anti_noise, noise prun_episode = jax.pmap(run_episode, in_axes=(None, 0, 0)) @@ -220,7 +257,6 @@ def compute_delta( weights: Fitness weights, vector of length population_size. Returns: - """ # NOTE: The trick "len(weights) -> len(weights) * perturbation_std" is # equivalent to tuning the l2_coef. @@ -234,55 +270,71 @@ def compute_delta( return -delta @jax.jit - def training_epoch(training_state: TrainingState, - key: PRNGKey) -> Tuple[TrainingState, Metrics]: + def training_epoch( + training_state: TrainingState, key: PRNGKey + ) -> Tuple[TrainingState, Metrics]: params = jax.tree_util.tree_map( lambda x: jnp.repeat( - jnp.expand_dims(x, axis=0), population_size, axis=0), - training_state.policy_params) + jnp.expand_dims(x, axis=0), population_size, axis=0 + ), + training_state.policy_params, + ) key, key_noise, key_es_eval = jax.random.split(key, 3) # generate perturbations params_with_noise, params_with_anti_noise, noise = add_noise( - params, key_noise) + params, key_noise + ) - pparams = jax.tree_util.tree_map(lambda a, b: jnp.concatenate([a, b], axis=0), - params_with_noise, params_with_anti_noise) + pparams = jax.tree_util.tree_map( + lambda a, b: jnp.concatenate([a, b], axis=0), + params_with_noise, + params_with_anti_noise, + ) pparams = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (local_devices_to_use, -1) + x.shape[1:]), - pparams) + pparams, + ) key_es_eval = jax.random.split(key_es_eval, local_devices_to_use) eval_scores, obs, obs_weights = prun_episode( - training_state.normalizer_params, pparams, key_es_eval) + training_state.normalizer_params, pparams, key_es_eval + ) obs = jnp.reshape(obs, (-1,) + obs.shape[2:]) obs_weights = jnp.reshape(obs_weights, (-1,) + obs_weights.shape[2:]) normalizer_params = running_statistics.update( - training_state.normalizer_params, obs, weights=obs_weights) + training_state.normalizer_params, obs, weights=obs_weights + ) weights = jnp.reshape(eval_scores, [-1]) weights = fitness_shaping.value(weights) if center_fitness: - weights = (weights - jnp.mean(weights)) / (1E-6 + jnp.std(weights)) + weights = (weights - jnp.mean(weights)) / (1e-6 + jnp.std(weights)) weights1, weights2 = jnp.split(weights, 2) weights = weights1 - weights2 delta = jax.tree_util.tree_map( functools.partial(compute_delta, weights=weights), - training_state.policy_params, noise) + training_state.policy_params, + noise, + ) params_update, optimizer_state = optimizer.update( - delta, training_state.optimizer_state) - policy_params = 
optax.apply_updates(training_state.policy_params, - params_update) + delta, training_state.optimizer_state + ) + policy_params = optax.apply_updates( + training_state.policy_params, params_update + ) - num_env_steps = training_state.num_env_steps + jnp.sum( - obs_weights, dtype=jnp.int32) * action_repeat + num_env_steps = ( + training_state.num_env_steps + + jnp.sum(obs_weights, dtype=jnp.int32) * action_repeat + ) metrics = { 'params_norm': optax.global_norm(policy_params), @@ -290,17 +342,22 @@ def training_epoch(training_state: TrainingState, 'eval_scores_std': jnp.std(eval_scores), 'weights': jnp.mean(weights), } - return (TrainingState( # type: ignore # jnp-type - normalizer_params=normalizer_params, - optimizer_state=optimizer_state, - policy_params=policy_params, - num_env_steps=num_env_steps), metrics) + return ( + TrainingState( # type: ignore # jnp-type + normalizer_params=normalizer_params, + optimizer_state=optimizer_state, + policy_params=policy_params, + num_env_steps=num_env_steps, + ), + metrics, + ) - training_walltime = 0. + training_walltime = 0.0 # Note that this is NOT a pure jittable method. - def training_epoch_with_timing(training_state: TrainingState, - key: PRNGKey) -> Tuple[TrainingState, Metrics]: + def training_epoch_with_timing( + training_state: TrainingState, key: PRNGKey + ) -> Tuple[TrainingState, Metrics]: nonlocal training_walltime t = time.time() (training_state, metrics) = training_epoch(training_state, key) @@ -313,19 +370,21 @@ def training_epoch_with_timing(training_state: TrainingState, metrics = { 'training/sps': sps, 'training/walltime': training_walltime, - **{f'training/{name}': value for name, value in metrics.items()} + **{f'training/{name}': value for name, value in metrics.items()}, } return training_state, metrics # pytype: disable=bad-return-type # py311-upgrade normalizer_params = running_statistics.init_state( - specs.Array((obs_size,), jnp.dtype('float32'))) + specs.Array((obs_size,), jnp.dtype('float32')) + ) policy_params = es_network.policy_network.init(network_key) optimizer_state = optimizer.init(policy_params) training_state = TrainingState( normalizer_params=normalizer_params, optimizer_state=optimizer_state, policy_params=policy_params, - num_env_steps=0) + num_env_steps=0, + ) if not eval_env: eval_env = environment @@ -348,12 +407,14 @@ def training_epoch_with_timing(training_state: TrainingState, num_eval_envs=num_eval_envs, episode_length=episode_length, action_repeat=action_repeat, - key=eval_key) + key=eval_key, + ) if num_evals > 1: metrics = evaluator.run_evaluation( (training_state.normalizer_params, training_state.policy_params), - training_metrics={}) + training_metrics={}, + ) logging.info(metrics) progress_fn(0, metrics) @@ -361,13 +422,15 @@ def training_epoch_with_timing(training_state: TrainingState, # optimization key, epoch_key = jax.random.split(key) training_state, training_metrics = training_epoch_with_timing( - training_state, epoch_key) + training_state, epoch_key + ) if training_state.num_env_steps >= next_eval_step: # Run evals. 
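Aside, not part of the patch: the fitness shaping above applies centered_rank and then differences the mirrored (antithetic) halves of the population. A minimal sketch of that combination, with made-up scores ordered as [f(θ+σε_i)..., f(θ−σε_i)...]:

import jax.numpy as jnp

def centered_rank(x):
  # Ranks mapped to [-0.5, 0.5]; see https://arxiv.org/abs/1703.03864.
  x = jnp.argsort(jnp.argsort(x))
  x /= len(x) - 1
  return x - 0.5

scores = jnp.array([3.0, -1.0, 2.0, 0.5])           # hypothetical eval_scores
w = centered_rank(scores)
w = (w - jnp.mean(w)) / (1e-6 + jnp.std(w))         # the center_fitness branch
w_pos, w_neg = jnp.split(w, 2)
per_noise_weight = w_pos - w_neg                    # one weight per noise sample ε_i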
metrics = evaluator.run_evaluation( (training_state.normalizer_params, training_state.policy_params), - training_metrics) + training_metrics, + ) logging.info(metrics) progress_fn(int(training_state.num_env_steps), metrics) next_eval_step += num_env_steps_between_evals diff --git a/brax/training/agents/es/train_test.py b/brax/training/agents/es/train_test.py index a7a42808..257dfaa3 100644 --- a/brax/training/agents/es/train_test.py +++ b/brax/training/agents/es/train_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Evolution Strategy training tests.""" + import pickle from absl.testing import absltest @@ -33,7 +34,8 @@ def testTrain(self): environment=envs.get_environment('fast'), num_timesteps=65536, episode_length=128, - learning_rate=0.1) + learning_rate=0.1, + ) self.assertGreater(metrics['eval/episode_reward'], 140) @parameterized.parameters(True, False) @@ -43,12 +45,14 @@ def testModelEncoding(self, normalize_observations): env, num_timesteps=128, episode_length=128, - normalize_observations=normalize_observations) + normalize_observations=normalize_observations, + ) normalize_fn = lambda x, y: x if normalize_observations: normalize_fn = running_statistics.normalize - es_network = es_networks.make_es_networks(env.observation_size, - env.action_size, normalize_fn) + es_network = es_networks.make_es_networks( + env.observation_size, env.action_size, normalize_fn + ) inference = es_networks.make_inference_fn(es_network) byte_encoding = pickle.dumps(params) decoded_params = pickle.loads(byte_encoding) diff --git a/brax/training/agents/ppo/losses.py b/brax/training/agents/ppo/losses.py index 8df1dea7..93d0b1e3 100644 --- a/brax/training/agents/ppo/losses.py +++ b/brax/training/agents/ppo/losses.py @@ -30,17 +30,20 @@ @flax.struct.dataclass class PPONetworkParams: """Contains training state for the learner.""" + policy: Params value: Params -def compute_gae(truncation: jnp.ndarray, - termination: jnp.ndarray, - rewards: jnp.ndarray, - values: jnp.ndarray, - bootstrap_value: jnp.ndarray, - lambda_: float = 1.0, - discount: float = 0.99): +def compute_gae( + truncation: jnp.ndarray, + termination: jnp.ndarray, + rewards: jnp.ndarray, + values: jnp.ndarray, + bootstrap_value: jnp.ndarray, + lambda_: float = 1.0, + discount: float = 0.99, +): """Calculates the Generalized Advantage Estimation (GAE). Args: @@ -65,7 +68,8 @@ def compute_gae(truncation: jnp.ndarray, truncation_mask = 1 - truncation # Append bootstrapped value to get [v1, ..., v_t+1] values_t_plus_1 = jnp.concatenate( - [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) + [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0 + ) deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values deltas *= truncation_mask @@ -79,17 +83,21 @@ def compute_vs_minus_v_xs(carry, target_t): return (lambda_, acc), (acc) (_, _), (vs_minus_v_xs) = jax.lax.scan( - compute_vs_minus_v_xs, (lambda_, acc), + compute_vs_minus_v_xs, + (lambda_, acc), (truncation_mask, deltas, termination), length=int(truncation_mask.shape[0]), - reverse=True) + reverse=True, + ) # Add V(x_s) to get v_s. 
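Aside, not part of the patch: the scan in compute_gae implements the usual GAE recursion A_t = δ_t + γλ(1 − termination_t)·A_{t+1} with δ_t = r_t + γ(1 − termination_t)·V_{t+1} − V_t. A plain reference loop for a single trajectory, assuming no truncation, looks like:

import jax.numpy as jnp

def gae_reference(rewards, values, bootstrap_value, termination,
                  discount=0.99, lambda_=0.95):
  values_t_plus_1 = jnp.append(values[1:], bootstrap_value)
  deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values
  acc, advantages = 0.0, []
  for t in reversed(range(len(rewards))):
    acc = deltas[t] + discount * lambda_ * (1 - termination[t]) * acc
    advantages.append(acc)
  return jnp.array(advantages[::-1])  # vs = advantages + values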
vs = jnp.add(vs_minus_v_xs, values) vs_t_plus_1 = jnp.concatenate( - [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - advantages = (rewards + discount * - (1 - termination) * vs_t_plus_1 - values) * truncation_mask + [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0 + ) + advantages = ( + rewards + discount * (1 - termination) * vs_t_plus_1 - values + ) * truncation_mask return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) @@ -104,7 +112,8 @@ def compute_ppo_loss( reward_scaling: float = 1.0, gae_lambda: float = 0.95, clipping_epsilon: float = 0.3, - normalize_advantage: bool = True) -> Tuple[jnp.ndarray, types.Metrics]: + normalize_advantage: bool = True, +) -> Tuple[jnp.ndarray, types.Metrics]: """Computes PPO loss. Args: @@ -112,7 +121,7 @@ def compute_ppo_loss( normalizer_params: Parameters of the normalizer. data: Transition that with leading dimension [B, T]. extra fields required are ['state_extras']['truncation'] ['policy_extras']['raw_action'] - ['policy_extras']['log_prob'] + ['policy_extras']['log_prob'] rng: Random key ppo_network: PPO networks. entropy_cost: entropy cost. @@ -131,8 +140,9 @@ def compute_ppo_loss( # Put the time dimension first. data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - policy_logits = policy_apply(normalizer_params, params.policy, - data.observation) + policy_logits = policy_apply( + normalizer_params, params.policy, data.observation + ) baseline = value_apply(normalizer_params, params.value, data.observation) terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation) @@ -143,7 +153,8 @@ def compute_ppo_loss( termination = (1 - data.discount) * (1 - truncation) target_action_log_probs = parametric_action_distribution.log_prob( - policy_logits, data.extras['policy_extras']['raw_action']) + policy_logits, data.extras['policy_extras']['raw_action'] + ) behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] vs, advantages = compute_gae( @@ -153,14 +164,16 @@ def compute_ppo_loss( values=baseline, bootstrap_value=bootstrap_value, lambda_=gae_lambda, - discount=discounting) + discount=discounting, + ) if normalize_advantage: advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) rho_s = jnp.exp(target_action_log_probs - behaviour_action_log_probs) surrogate_loss1 = rho_s * advantages - surrogate_loss2 = jnp.clip(rho_s, 1 - clipping_epsilon, - 1 + clipping_epsilon) * advantages + surrogate_loss2 = ( + jnp.clip(rho_s, 1 - clipping_epsilon, 1 + clipping_epsilon) * advantages + ) policy_loss = -jnp.mean(jnp.minimum(surrogate_loss1, surrogate_loss2)) @@ -177,5 +190,5 @@ def compute_ppo_loss( 'total_loss': total_loss, 'policy_loss': policy_loss, 'v_loss': v_loss, - 'entropy_loss': entropy_loss + 'entropy_loss': entropy_loss, } diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py index 76336c22..1922bad3 100644 --- a/brax/training/agents/ppo/networks.py +++ b/brax/training/agents/ppo/networks.py @@ -40,20 +40,23 @@ def make_policy( policy_network = ppo_networks.policy_network parametric_action_distribution = ppo_networks.parametric_action_distribution - def policy(observations: types.Observation, - key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + def policy( + observations: types.Observation, key_sample: PRNGKey + ) -> Tuple[types.Action, types.Extra]: param_subset = (params[0], params[1]) # normalizer and policy params logits = policy_network.apply(*param_subset, observations) if deterministic: return 
ppo_networks.parametric_action_distribution.mode(logits), {} raw_actions = parametric_action_distribution.sample_no_postprocessing( - logits, key_sample) + logits, key_sample + ) log_prob = parametric_action_distribution.log_prob(logits, raw_actions) postprocessed_actions = parametric_action_distribution.postprocess( - raw_actions) + raw_actions + ) return postprocessed_actions, { 'log_prob': log_prob, - 'raw_action': raw_actions + 'raw_action': raw_actions, } return policy @@ -64,31 +67,35 @@ def policy(observations: types.Observation, def make_ppo_networks( observation_size: int, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, value_hidden_layer_sizes: Sequence[int] = (256,) * 5, activation: networks.ActivationFn = linen.swish, policy_obs_key: str = 'state', - value_obs_key: str = 'state') -> PPONetworks: + value_obs_key: str = 'state', +) -> PPONetworks: """Make PPO networks with preprocessor.""" parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size) + event_size=action_size + ) policy_network = networks.make_policy_network( parametric_action_distribution.param_size, observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=policy_hidden_layer_sizes, activation=activation, - obs_key=policy_obs_key) + obs_key=policy_obs_key, + ) value_network = networks.make_value_network( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=value_hidden_layer_sizes, activation=activation, - obs_key=value_obs_key) + obs_key=value_obs_key, + ) return PPONetworks( policy_network=policy_network, value_network=value_network, - parametric_action_distribution=parametric_action_distribution) + parametric_action_distribution=parametric_action_distribution, + ) diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 450b8837..c7295ac7 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -54,6 +54,7 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" + optimizer_state: optax.OptState params: ppo_losses.PPONetworkParams normalizer_params: running_statistics.RunningStatisticsState @@ -70,6 +71,7 @@ def _strip_weak_type(tree): def f(leaf): leaf = jnp.asarray(leaf) return leaf.astype(leaf.dtype) + return jax.tree_util.tree_map(f, tree) @@ -260,13 +262,19 @@ def train( local_devices_to_use = min(local_devices_to_use, max_devices_per_host) logging.info( 'Device count: %d, process count: %d (id %d), local device count: %d, ' - 'devices to be used count: %d', jax.device_count(), process_count, - process_id, local_device_count, local_devices_to_use) + 'devices to be used count: %d', + jax.device_count(), + process_count, + process_id, + local_device_count, + local_devices_to_use, + ) device_count = local_devices_to_use * process_count # The number of environment steps executed for every training step. env_step_per_training_step = ( - batch_size * unroll_length * num_minibatches * action_repeat) + batch_size * unroll_length * num_minibatches * action_repeat + ) num_evals_after_init = max(num_evals - 1, 1) # The number of training_step calls per training_epoch call. 
# equals to ceil(num_timesteps / (num_evals * env_step_per_training_step * @@ -315,8 +323,9 @@ def train( reset_fn = jax.jit(jax.vmap(env.reset)) key_envs = jax.random.split(key_env, num_envs // process_count) - key_envs = jnp.reshape(key_envs, - (local_devices_to_use, -1) + key_envs.shape[1:]) + key_envs = jnp.reshape( + key_envs, (local_devices_to_use, -1) + key_envs.shape[1:] + ) env_state = reset_fn(key_envs) # Discard the batch axes over devices and envs. obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs) @@ -325,9 +334,8 @@ def train( if normalize_observations: normalize = running_statistics.normalize ppo_network = network_factory( - obs_shape, - env.action_size, - preprocess_observations_fn=normalize) + obs_shape, env.action_size, preprocess_observations_fn=normalize + ) make_policy = ppo_networks.make_inference_fn(ppo_network) optimizer = optax.adam(learning_rate=learning_rate) @@ -335,7 +343,7 @@ def train( # TODO: Move gradient clipping to `training/gradients.py`. optimizer = optax.chain( optax.clip_by_global_norm(max_grad_norm), - optax.adam(learning_rate=learning_rate) + optax.adam(learning_rate=learning_rate), ) loss_fn = functools.partial( @@ -346,14 +354,18 @@ def train( reward_scaling=reward_scaling, gae_lambda=gae_lambda, clipping_epsilon=clipping_epsilon, - normalize_advantage=normalize_advantage) + normalize_advantage=normalize_advantage, + ) gradient_update_fn = gradients.gradient_update_fn( - loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True + ) def minibatch_step( - carry, data: types.Transition, - normalizer_params: running_statistics.RunningStatisticsState): + carry, + data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState, + ): optimizer_state, params, key = carry key, key_loss = jax.random.split(key) (_, metrics), params, optimizer_state = gradient_update_fn( @@ -361,12 +373,17 @@ def minibatch_step( normalizer_params, data, key_loss, - optimizer_state=optimizer_state) + optimizer_state=optimizer_state, + ) return (optimizer_state, params, key), metrics - def sgd_step(carry, unused_t, data: types.Transition, - normalizer_params: running_statistics.RunningStatisticsState): + def sgd_step( + carry, + unused_t, + data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState, + ): optimizer_state, params, key = carry key, key_perm, key_grad = jax.random.split(key, 3) @@ -392,12 +409,13 @@ def convert_data(x: jnp.ndarray): functools.partial(minibatch_step, normalizer_params=normalizer_params), (optimizer_state, params, key_grad), shuffled_data, - length=num_minibatches) + length=num_minibatches, + ) return (optimizer_state, params, key), metrics def training_step( - carry: Tuple[TrainingState, envs.State, PRNGKey], - unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: + carry: Tuple[TrainingState, envs.State, PRNGKey], unused_t + ) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) @@ -416,16 +434,21 @@ def f(carry, unused_t): policy, current_key, unroll_length, - extra_fields=('truncation',)) + extra_fields=('truncation',), + ) return (next_state, next_key), data (state, _), data = jax.lax.scan( - f, (state, key_generate_unroll), (), - length=batch_size * num_minibatches // num_envs) + f, + (state, key_generate_unroll), + (), + length=batch_size * num_minibatches // num_envs, + ) 
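Aside, not part of the patch: the scan length batch_size * num_minibatches // num_envs follows from each unroll contributing num_envs trajectories, so one training step gathers exactly env_step_per_training_step environment steps. A quick check with hypothetical sizes:

num_envs, batch_size, num_minibatches = 64, 256, 4
unroll_length, action_repeat = 10, 1

num_unrolls = batch_size * num_minibatches // num_envs              # 16 scan iterations
env_step_per_training_step = (
    batch_size * unroll_length * num_minibatches * action_repeat)   # 10240
assert num_unrolls * num_envs * unroll_length * action_repeat == env_step_per_training_step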
# Have leading dimensions (batch_size * num_minibatches, unroll_length) data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), - data) + data = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), data + ) assert data.discount.shape[1:] == (unroll_length,) # Update normalization params and normalize observations. @@ -437,22 +460,30 @@ def f(carry, unused_t): (optimizer_state, params, _), metrics = jax.lax.scan( functools.partial( - sgd_step, data=data, normalizer_params=normalizer_params), - (training_state.optimizer_state, training_state.params, key_sgd), (), - length=num_updates_per_batch) + sgd_step, data=data, normalizer_params=normalizer_params + ), + (training_state.optimizer_state, training_state.params, key_sgd), + (), + length=num_updates_per_batch, + ) new_training_state = TrainingState( optimizer_state=optimizer_state, params=params, normalizer_params=normalizer_params, - env_steps=training_state.env_steps + env_step_per_training_step) + env_steps=training_state.env_steps + env_step_per_training_step, + ) return (new_training_state, state, new_key), metrics - def training_epoch(training_state: TrainingState, state: envs.State, - key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + def training_epoch( + training_state: TrainingState, state: envs.State, key: PRNGKey + ) -> Tuple[TrainingState, envs.State, Metrics]: (training_state, state, _), loss_metrics = jax.lax.scan( - training_step, (training_state, state, key), (), - length=num_training_steps_per_epoch) + training_step, + (training_state, state, key), + (), + length=num_training_steps_per_epoch, + ) loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics) return training_state, state, loss_metrics @@ -460,8 +491,8 @@ def training_epoch(training_state: TrainingState, state: envs.State, # Note that this is NOT a pure jittable method. 
def training_epoch_with_timing( - training_state: TrainingState, env_state: envs.State, - key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + training_state: TrainingState, env_state: envs.State, key: PRNGKey + ) -> Tuple[TrainingState, envs.State, Metrics]: nonlocal training_walltime t = time.time() training_state, env_state = _strip_weak_type((training_state, env_state)) @@ -473,13 +504,15 @@ def training_epoch_with_timing( epoch_training_time = time.time() - t training_walltime += epoch_training_time - sps = (num_training_steps_per_epoch * - env_step_per_training_step * - max(num_resets_per_eval, 1)) / epoch_training_time + sps = ( + num_training_steps_per_epoch + * env_step_per_training_step + * max(num_resets_per_eval, 1) + ) / epoch_training_time metrics = { 'training/sps': sps, 'training/walltime': training_walltime, - **{f'training/{name}': value for name, value in metrics.items()} + **{f'training/{name}': value for name, value in metrics.items()}, } return training_state, env_state, metrics # pytype: disable=bad-return-type # py311-upgrade @@ -527,8 +560,8 @@ def training_epoch_with_timing( ) training_state = jax.device_put_replicated( - training_state, - jax.local_devices()[:local_devices_to_use]) + training_state, jax.local_devices()[:local_devices_to_use] + ) if not eval_env: eval_env = environment @@ -550,7 +583,8 @@ def training_epoch_with_timing( num_eval_envs=num_eval_envs, episode_length=episode_length, action_repeat=action_repeat, - key=eval_key) + key=eval_key, + ) # Run initial eval metrics = {} @@ -582,8 +616,8 @@ def training_epoch_with_timing( current_step = int(_unpmap(training_state.env_steps)) key_envs = jax.vmap( - lambda x, s: jax.random.split(x[0], s), - in_axes=(0, None))(key_envs, key_envs.shape[1]) + lambda x, s: jax.random.split(x[0], s), in_axes=(0, None) + )(key_envs, key_envs.shape[1]) # TODO: move extra reset logic to the AutoResetWrapper. env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state diff --git a/brax/training/agents/ppo/train_test.py b/brax/training/agents/ppo/train_test.py index c733fc4c..58111591 100644 --- a/brax/training/agents/ppo/train_test.py +++ b/brax/training/agents/ppo/train_test.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""PPO tests.""" + import functools import pickle from absl.testing import absltest @@ -29,10 +30,10 @@ class PPOTest(parameterized.TestCase): """Tests for PPO module.""" - @parameterized.parameters("ndarray", "dict_state") + @parameterized.parameters('ndarray', 'dict_state') def testTrain(self, obs_mode): """Test PPO with a simple env.""" - fast = envs.get_environment("fast", obs_mode=obs_mode) + fast = envs.get_environment('fast', obs_mode=obs_mode) _, _, metrics = ppo.train( fast, num_timesteps=2**15, @@ -49,7 +50,8 @@ def testTrain(self, obs_mode): seed=2, num_evals=3, reward_scaling=10, - normalize_advantage=False) + normalize_advantage=False, + ) self.assertGreater(metrics['eval/episode_reward'], 135) self.assertEqual(fast.reset_count, 2) # type: ignore self.assertEqual(fast.step_count, 2) # type: ignore @@ -76,7 +78,9 @@ def testTrainV2(self): def testTrainAsymmetricActorCritic(self): """Test PPO with asymmetric actor critic.""" - env = envs.get_environment('fast', asymmetric_obs=True, obs_mode='dict_state') + env = envs.get_environment( + 'fast', asymmetric_obs=True, obs_mode='dict_state' + ) network_factory = functools.partial( ppo_networks.make_ppo_networks, @@ -122,12 +126,14 @@ def testNetworkEncoding(self, normalize_observations): num_timesteps=128, episode_length=128, num_envs=128, - normalize_observations=normalize_observations) + normalize_observations=normalize_observations, + ) normalize_fn = lambda x, y: x if normalize_observations: normalize_fn = running_statistics.normalize - ppo_network = ppo_networks.make_ppo_networks(env.observation_size, - env.action_size, normalize_fn) + ppo_network = ppo_networks.make_ppo_networks( + env.observation_size, env.action_size, normalize_fn + ) inference = ppo_networks.make_inference_fn(ppo_network) byte_encoding = pickle.dumps(params) decoded_params = pickle.loads(byte_encoding) @@ -135,7 +141,8 @@ def testNetworkEncoding(self, normalize_observations): # Compute one action. 
state = env.reset(jax.random.PRNGKey(0)) original_action = original_inference(decoded_params)( - state.obs, jax.random.PRNGKey(0))[0] + state.obs, jax.random.PRNGKey(0) + )[0] action = inference(decoded_params)(state.obs, jax.random.PRNGKey(0))[0] self.assertSequenceEqual(original_action, action) env.step(state, action) diff --git a/brax/training/agents/sac/losses.py b/brax/training/agents/sac/losses.py index d92468e8..36a40427 100644 --- a/brax/training/agents/sac/losses.py +++ b/brax/training/agents/sac/losses.py @@ -16,6 +16,7 @@ See: https://arxiv.org/pdf/1812.05905.pdf """ + from typing import Any from brax.training import types @@ -28,8 +29,12 @@ Transition = types.Transition -def make_losses(sac_network: sac_networks.SACNetworks, reward_scaling: float, - discounting: float, action_size: int): +def make_losses( + sac_network: sac_networks.SACNetworks, + reward_scaling: float, + discounting: float, + action_size: int, +): """Creates the SAC losses.""" target_entropy = -0.5 * action_size @@ -37,38 +42,58 @@ def make_losses(sac_network: sac_networks.SACNetworks, reward_scaling: float, q_network = sac_network.q_network parametric_action_distribution = sac_network.parametric_action_distribution - def alpha_loss(log_alpha: jnp.ndarray, policy_params: Params, - normalizer_params: Any, transitions: Transition, - key: PRNGKey) -> jnp.ndarray: + def alpha_loss( + log_alpha: jnp.ndarray, + policy_params: Params, + normalizer_params: Any, + transitions: Transition, + key: PRNGKey, + ) -> jnp.ndarray: """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" - dist_params = policy_network.apply(normalizer_params, policy_params, - transitions.observation) + dist_params = policy_network.apply( + normalizer_params, policy_params, transitions.observation + ) action = parametric_action_distribution.sample_no_postprocessing( - dist_params, key) + dist_params, key + ) log_prob = parametric_action_distribution.log_prob(dist_params, action) alpha = jnp.exp(log_alpha) alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - target_entropy) return jnp.mean(alpha_loss) - def critic_loss(q_params: Params, policy_params: Params, - normalizer_params: Any, target_q_params: Params, - alpha: jnp.ndarray, transitions: Transition, - key: PRNGKey) -> jnp.ndarray: - q_old_action = q_network.apply(normalizer_params, q_params, - transitions.observation, transitions.action) - next_dist_params = policy_network.apply(normalizer_params, policy_params, - transitions.next_observation) + def critic_loss( + q_params: Params, + policy_params: Params, + normalizer_params: Any, + target_q_params: Params, + alpha: jnp.ndarray, + transitions: Transition, + key: PRNGKey, + ) -> jnp.ndarray: + q_old_action = q_network.apply( + normalizer_params, q_params, transitions.observation, transitions.action + ) + next_dist_params = policy_network.apply( + normalizer_params, policy_params, transitions.next_observation + ) next_action = parametric_action_distribution.sample_no_postprocessing( - next_dist_params, key) + next_dist_params, key + ) next_log_prob = parametric_action_distribution.log_prob( - next_dist_params, next_action) + next_dist_params, next_action + ) next_action = parametric_action_distribution.postprocess(next_action) - next_q = q_network.apply(normalizer_params, target_q_params, - transitions.next_observation, next_action) + next_q = q_network.apply( + normalizer_params, + target_q_params, + transitions.next_observation, + next_action, + ) next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob - target_q = 
jax.lax.stop_gradient(transitions.reward * reward_scaling + - transitions.discount * discounting * - next_v) + target_q = jax.lax.stop_gradient( + transitions.reward * reward_scaling + + transitions.discount * discounting * next_v + ) q_error = q_old_action - jnp.expand_dims(target_q, -1) # Better bootstrapping for truncated episodes. @@ -78,17 +103,25 @@ def critic_loss(q_params: Params, policy_params: Params, q_loss = 0.5 * jnp.mean(jnp.square(q_error)) return q_loss - def actor_loss(policy_params: Params, normalizer_params: Any, - q_params: Params, alpha: jnp.ndarray, transitions: Transition, - key: PRNGKey) -> jnp.ndarray: - dist_params = policy_network.apply(normalizer_params, policy_params, - transitions.observation) + def actor_loss( + policy_params: Params, + normalizer_params: Any, + q_params: Params, + alpha: jnp.ndarray, + transitions: Transition, + key: PRNGKey, + ) -> jnp.ndarray: + dist_params = policy_network.apply( + normalizer_params, policy_params, transitions.observation + ) action = parametric_action_distribution.sample_no_postprocessing( - dist_params, key) + dist_params, key + ) log_prob = parametric_action_distribution.log_prob(dist_params, action) action = parametric_action_distribution.postprocess(action) - q_action = q_network.apply(normalizer_params, q_params, - transitions.observation, action) + q_action = q_network.apply( + normalizer_params, q_params, transitions.observation, action + ) min_q = jnp.min(q_action, axis=-1) actor_loss = alpha * log_prob - min_q return jnp.mean(actor_loss) diff --git a/brax/training/agents/sac/networks.py b/brax/training/agents/sac/networks.py index dc50106a..56b819e6 100644 --- a/brax/training/agents/sac/networks.py +++ b/brax/training/agents/sac/networks.py @@ -34,16 +34,22 @@ class SACNetworks: def make_inference_fn(sac_networks: SACNetworks): """Creates params and inference function for the SAC agent.""" - def make_policy(params: types.PolicyParams, - deterministic: bool = False) -> types.Policy: + def make_policy( + params: types.PolicyParams, deterministic: bool = False + ) -> types.Policy: - def policy(observations: types.Observation, - key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + def policy( + observations: types.Observation, key_sample: PRNGKey + ) -> Tuple[types.Action, types.Extra]: logits = sac_networks.policy_network.apply(*params, observations) if deterministic: return sac_networks.parametric_action_distribution.mode(logits), {} - return sac_networks.parametric_action_distribution.sample( - logits, key_sample), {} + return ( + sac_networks.parametric_action_distribution.sample( + logits, key_sample + ), + {}, + ) return policy @@ -53,30 +59,34 @@ def policy(observations: types.Observation, def make_sac_networks( observation_size: int, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), activation: networks.ActivationFn = linen.relu, policy_network_layer_norm: bool = False, - q_network_layer_norm: bool = False) -> SACNetworks: + q_network_layer_norm: bool = False, +) -> SACNetworks: """Make SAC networks.""" parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size) + event_size=action_size + ) policy_network = networks.make_policy_network( parametric_action_distribution.param_size, observation_size, 
preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=hidden_layer_sizes, activation=activation, - layer_norm=policy_network_layer_norm) + layer_norm=policy_network_layer_norm, + ) q_network = networks.make_q_network( observation_size, action_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=hidden_layer_sizes, activation=activation, - layer_norm=q_network_layer_norm) + layer_norm=q_network_layer_norm, + ) return SACNetworks( policy_network=policy_network, q_network=q_network, - parametric_action_distribution=parametric_action_distribution) + parametric_action_distribution=parametric_action_distribution, + ) diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index e4b3c2ff..82a4b927 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -54,6 +54,7 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" + policy_optimizer_state: optax.OptState policy_params: Params q_optimizer_state: optax.OptState @@ -71,14 +72,17 @@ def _unpmap(v): def _init_training_state( - key: PRNGKey, obs_size: int, local_devices_to_use: int, + key: PRNGKey, + obs_size: int, + local_devices_to_use: int, sac_network: sac_networks.SACNetworks, alpha_optimizer: optax.GradientTransformation, policy_optimizer: optax.GradientTransformation, - q_optimizer: optax.GradientTransformation) -> TrainingState: + q_optimizer: optax.GradientTransformation, +) -> TrainingState: """Inits the training state and replicates it over devices.""" key_policy, key_q = jax.random.split(key) - log_alpha = jnp.asarray(0., dtype=jnp.float32) + log_alpha = jnp.asarray(0.0, dtype=jnp.float32) alpha_optimizer_state = alpha_optimizer.init(log_alpha) policy_params = sac_network.policy_network.init(key_policy) @@ -87,7 +91,8 @@ def _init_training_state( q_optimizer_state = q_optimizer.init(q_params) normalizer_params = running_statistics.init_state( - specs.Array((obs_size,), jnp.dtype('float32'))) + specs.Array((obs_size,), jnp.dtype('float32')) + ) training_state = TrainingState( policy_optimizer_state=policy_optimizer_state, @@ -99,9 +104,11 @@ def _init_training_state( env_steps=jnp.zeros(()), alpha_optimizer_state=alpha_optimizer_state, alpha_params=log_alpha, - normalizer_params=normalizer_params) - return jax.device_put_replicated(training_state, - jax.local_devices()[:local_devices_to_use]) + normalizer_params=normalizer_params, + ) + return jax.device_put_replicated( + training_state, jax.local_devices()[:local_devices_to_use] + ) def train( @@ -141,12 +148,16 @@ def train( if max_devices_per_host is not None: local_devices_to_use = min(local_devices_to_use, max_devices_per_host) device_count = local_devices_to_use * jax.process_count() - logging.info('local_device_count: %s; total_device_count: %s', - local_devices_to_use, device_count) + logging.info( + 'local_device_count: %s; total_device_count: %s', + local_devices_to_use, + device_count, + ) if min_replay_size >= num_timesteps: raise ValueError( - 'No training will happen because min_replay_size >= num_timesteps') + 'No training will happen because min_replay_size >= num_timesteps' + ) if max_replay_size is None: max_replay_size = num_timesteps @@ -163,8 +174,9 @@ def train( # ceil(num_timesteps - num_prefill_env_steps / # (num_evals_after_init * env_steps_per_actor_step)) num_training_steps_per_epoch = -( - -(num_timesteps - num_prefill_env_steps) // - (num_evals_after_init * env_steps_per_actor_step)) + -(num_timesteps - 
num_prefill_env_steps) + // (num_evals_after_init * env_steps_per_actor_step) + ) assert num_envs % device_count == 0 env = environment @@ -181,7 +193,8 @@ def train( v_randomization_fn = functools.partial( randomization_fn, rng=jax.random.split( - key, num_envs // jax.process_count() // local_devices_to_use), + key, num_envs // jax.process_count() // local_devices_to_use + ), ) env = wrap_for_training( env, @@ -201,7 +214,8 @@ def train( sac_network = network_factory( observation_size=obs_size, action_size=action_size, - preprocess_observations_fn=normalize_fn) + preprocess_observations_fn=normalize_fn, + ) make_policy = sac_networks.make_inference_fn(sac_network) alpha_optimizer = optax.adam(learning_rate=3e-4) @@ -214,35 +228,36 @@ def train( dummy_transition = Transition( # pytype: disable=wrong-arg-types # jax-ndarray observation=dummy_obs, action=dummy_action, - reward=0., - discount=0., + reward=0.0, + discount=0.0, next_observation=dummy_obs, - extras={ - 'state_extras': { - 'truncation': 0. - }, - 'policy_extras': {} - }) + extras={'state_extras': {'truncation': 0.0}, 'policy_extras': {}}, + ) replay_buffer = replay_buffers.UniformSamplingQueue( max_replay_size=max_replay_size // device_count, dummy_data_sample=dummy_transition, - sample_batch_size=batch_size * grad_updates_per_step // device_count) + sample_batch_size=batch_size * grad_updates_per_step // device_count, + ) alpha_loss, critic_loss, actor_loss = sac_losses.make_losses( sac_network=sac_network, reward_scaling=reward_scaling, discounting=discounting, - action_size=action_size) + action_size=action_size, + ) alpha_update = gradients.gradient_update_fn( # pytype: disable=wrong-arg-types # jax-ndarray - alpha_loss, alpha_optimizer, pmap_axis_name=_PMAP_AXIS_NAME) + alpha_loss, alpha_optimizer, pmap_axis_name=_PMAP_AXIS_NAME + ) critic_update = gradients.gradient_update_fn( # pytype: disable=wrong-arg-types # jax-ndarray - critic_loss, q_optimizer, pmap_axis_name=_PMAP_AXIS_NAME) + critic_loss, q_optimizer, pmap_axis_name=_PMAP_AXIS_NAME + ) actor_update = gradients.gradient_update_fn( # pytype: disable=wrong-arg-types # jax-ndarray - actor_loss, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME) + actor_loss, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME + ) def sgd_step( - carry: Tuple[TrainingState, PRNGKey], - transitions: Transition) -> Tuple[Tuple[TrainingState, PRNGKey], Metrics]: + carry: Tuple[TrainingState, PRNGKey], transitions: Transition + ) -> Tuple[Tuple[TrainingState, PRNGKey], Metrics]: training_state, key = carry key, key_alpha, key_critic, key_actor = jax.random.split(key, 4) @@ -253,7 +268,8 @@ def sgd_step( training_state.normalizer_params, transitions, key_alpha, - optimizer_state=training_state.alpha_optimizer_state) + optimizer_state=training_state.alpha_optimizer_state, + ) alpha = jnp.exp(training_state.alpha_params) critic_loss, q_params, q_optimizer_state = critic_update( training_state.q_params, @@ -263,7 +279,8 @@ def sgd_step( alpha, transitions, key_critic, - optimizer_state=training_state.q_optimizer_state) + optimizer_state=training_state.q_optimizer_state, + ) actor_loss, policy_params, policy_optimizer_state = actor_update( training_state.policy_params, training_state.normalizer_params, @@ -271,11 +288,14 @@ def sgd_step( alpha, transitions, key_actor, - optimizer_state=training_state.policy_optimizer_state) + optimizer_state=training_state.policy_optimizer_state, + ) new_target_q_params = jax.tree_util.tree_map( - lambda x, y: x * (1 - tau) + y * tau, training_state.target_q_params, 
- q_params) + lambda x, y: x * (1 - tau) + y * tau, + training_state.target_q_params, + q_params, + ) metrics = { 'critic_loss': critic_loss, @@ -294,56 +314,78 @@ def sgd_step( env_steps=training_state.env_steps, alpha_optimizer_state=alpha_optimizer_state, alpha_params=alpha_params, - normalizer_params=training_state.normalizer_params) + normalizer_params=training_state.normalizer_params, + ) return (new_training_state, key), metrics def get_experience( normalizer_params: running_statistics.RunningStatisticsState, - policy_params: Params, env_state: Union[envs.State, envs_v1.State], - buffer_state: ReplayBufferState, key: PRNGKey - ) -> Tuple[running_statistics.RunningStatisticsState, - Union[envs.State, envs_v1.State], ReplayBufferState]: + policy_params: Params, + env_state: Union[envs.State, envs_v1.State], + buffer_state: ReplayBufferState, + key: PRNGKey, + ) -> Tuple[ + running_statistics.RunningStatisticsState, + Union[envs.State, envs_v1.State], + ReplayBufferState, + ]: policy = make_policy((normalizer_params, policy_params)) env_state, transitions = acting.actor_step( - env, env_state, policy, key, extra_fields=('truncation',)) + env, env_state, policy, key, extra_fields=('truncation',) + ) normalizer_params = running_statistics.update( normalizer_params, transitions.observation, - pmap_axis_name=_PMAP_AXIS_NAME) + pmap_axis_name=_PMAP_AXIS_NAME, + ) buffer_state = replay_buffer.insert(buffer_state, transitions) return normalizer_params, env_state, buffer_state def training_step( - training_state: TrainingState, env_state: envs.State, - buffer_state: ReplayBufferState, key: PRNGKey - ) -> Tuple[TrainingState, Union[envs.State, envs_v1.State], - ReplayBufferState, Metrics]: + training_state: TrainingState, + env_state: envs.State, + buffer_state: ReplayBufferState, + key: PRNGKey, + ) -> Tuple[ + TrainingState, + Union[envs.State, envs_v1.State], + ReplayBufferState, + Metrics, + ]: experience_key, training_key = jax.random.split(key) normalizer_params, env_state, buffer_state = get_experience( - training_state.normalizer_params, training_state.policy_params, - env_state, buffer_state, experience_key) + training_state.normalizer_params, + training_state.policy_params, + env_state, + buffer_state, + experience_key, + ) training_state = training_state.replace( normalizer_params=normalizer_params, - env_steps=training_state.env_steps + env_steps_per_actor_step) + env_steps=training_state.env_steps + env_steps_per_actor_step, + ) buffer_state, transitions = replay_buffer.sample(buffer_state) # Change the front dimension of transitions so 'update_step' is called # grad_updates_per_step times by the scan. 
transitions = jax.tree_util.tree_map( lambda x: jnp.reshape(x, (grad_updates_per_step, -1) + x.shape[1:]), - transitions) - (training_state, _), metrics = jax.lax.scan(sgd_step, - (training_state, training_key), - transitions) + transitions, + ) + (training_state, _), metrics = jax.lax.scan( + sgd_step, (training_state, training_key), transitions + ) metrics['buffer_current_size'] = replay_buffer.size(buffer_state) return training_state, env_state, buffer_state, metrics def prefill_replay_buffer( - training_state: TrainingState, env_state: envs.State, - buffer_state: ReplayBufferState, key: PRNGKey + training_state: TrainingState, + env_state: envs.State, + buffer_state: ReplayBufferState, + key: PRNGKey, ) -> Tuple[TrainingState, envs.State, ReplayBufferState, PRNGKey]: def f(carry, unused): @@ -351,23 +393,34 @@ def f(carry, unused): training_state, env_state, buffer_state, key = carry key, new_key = jax.random.split(key) new_normalizer_params, env_state, buffer_state = get_experience( - training_state.normalizer_params, training_state.policy_params, - env_state, buffer_state, key) + training_state.normalizer_params, + training_state.policy_params, + env_state, + buffer_state, + key, + ) new_training_state = training_state.replace( normalizer_params=new_normalizer_params, - env_steps=training_state.env_steps + env_steps_per_actor_step) + env_steps=training_state.env_steps + env_steps_per_actor_step, + ) return (new_training_state, env_state, buffer_state, new_key), () return jax.lax.scan( - f, (training_state, env_state, buffer_state, key), (), - length=num_prefill_actor_steps)[0] + f, + (training_state, env_state, buffer_state, key), + (), + length=num_prefill_actor_steps, + )[0] prefill_replay_buffer = jax.pmap( - prefill_replay_buffer, axis_name=_PMAP_AXIS_NAME) + prefill_replay_buffer, axis_name=_PMAP_AXIS_NAME + ) def training_epoch( - training_state: TrainingState, env_state: envs.State, - buffer_state: ReplayBufferState, key: PRNGKey + training_state: TrainingState, + env_state: envs.State, + buffer_state: ReplayBufferState, + key: PRNGKey, ) -> Tuple[TrainingState, envs.State, ReplayBufferState, Metrics]: def f(carry, unused_t): @@ -377,8 +430,11 @@ def f(carry, unused_t): return (ts, es, bs, new_key), metrics (training_state, env_state, buffer_state, key), metrics = jax.lax.scan( - f, (training_state, env_state, buffer_state, key), (), - length=num_training_steps_per_epoch) + f, + (training_state, env_state, buffer_state, key), + (), + length=num_training_steps_per_epoch, + ) metrics = jax.tree_util.tree_map(jnp.mean, metrics) return training_state, env_state, buffer_state, metrics @@ -386,24 +442,28 @@ def f(carry, unused_t): # Note that this is NOT a pure jittable method. 
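Aside, not part of the patch: the reshape above carves one sampled batch into grad_updates_per_step minibatches so a single lax.scan runs several SGD steps per actor step, and each step also soft-updates the target critic with the τ-weighted average shown earlier. A self-contained sketch of that shape bookkeeping, with stand-in sizes and a stand-in update rule:

import jax
import jax.numpy as jnp

grad_updates_per_step, batch, obs_dim, tau = 4, 32, 8, 0.005
transitions = jnp.ones((grad_updates_per_step * batch, obs_dim))
minibatches = jnp.reshape(
    transitions, (grad_updates_per_step, -1) + transitions.shape[1:])  # (4, 32, 8)

def sgd_step(carry, batch_t):
  params, target_params = carry
  params = params - 1e-3 * jnp.mean(batch_t)                 # placeholder update
  target_params = target_params * (1 - tau) + params * tau   # Polyak averaging
  return (params, target_params), None

(params, target_params), _ = jax.lax.scan(
    sgd_step, (jnp.zeros(obs_dim), jnp.zeros(obs_dim)), minibatches)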
def training_epoch_with_timing( - training_state: TrainingState, env_state: envs.State, - buffer_state: ReplayBufferState, key: PRNGKey + training_state: TrainingState, + env_state: envs.State, + buffer_state: ReplayBufferState, + key: PRNGKey, ) -> Tuple[TrainingState, envs.State, ReplayBufferState, Metrics]: nonlocal training_walltime t = time.time() - (training_state, env_state, buffer_state, - metrics) = training_epoch(training_state, env_state, buffer_state, key) + (training_state, env_state, buffer_state, metrics) = training_epoch( + training_state, env_state, buffer_state, key + ) metrics = jax.tree_util.tree_map(jnp.mean, metrics) jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) epoch_training_time = time.time() - t training_walltime += epoch_training_time - sps = (env_steps_per_actor_step * - num_training_steps_per_epoch) / epoch_training_time + sps = ( + env_steps_per_actor_step * num_training_steps_per_epoch + ) / epoch_training_time metrics = { 'training/sps': sps, 'training/walltime': training_walltime, - **{f'training/{name}': value for name, value in metrics.items()} + **{f'training/{name}': value for name, value in metrics.items()}, } return training_state, env_state, buffer_state, metrics # pytype: disable=bad-return-type # py311-upgrade @@ -418,20 +478,23 @@ def training_epoch_with_timing( sac_network=sac_network, alpha_optimizer=alpha_optimizer, policy_optimizer=policy_optimizer, - q_optimizer=q_optimizer) + q_optimizer=q_optimizer, + ) del global_key local_key, rb_key, env_key, eval_key = jax.random.split(local_key, 4) # Env init env_keys = jax.random.split(env_key, num_envs // jax.process_count()) - env_keys = jnp.reshape(env_keys, - (local_devices_to_use, -1) + env_keys.shape[1:]) + env_keys = jnp.reshape( + env_keys, (local_devices_to_use, -1) + env_keys.shape[1:] + ) env_state = jax.pmap(env.reset)(env_keys) # Replay buffer init buffer_state = jax.pmap(replay_buffer.init)( - jax.random.split(rb_key, local_devices_to_use)) + jax.random.split(rb_key, local_devices_to_use) + ) if not eval_env: eval_env = environment @@ -453,15 +516,18 @@ def training_epoch_with_timing( num_eval_envs=num_eval_envs, episode_length=episode_length, action_repeat=action_repeat, - key=eval_key) + key=eval_key, + ) # Run initial eval metrics = {} if process_id == 0 and num_evals > 1: metrics = evaluator.run_evaluation( _unpmap( - (training_state.normalizer_params, training_state.policy_params)), - training_metrics={}) + (training_state.normalizer_params, training_state.policy_params) + ), + training_metrics={}, + ) logging.info(metrics) progress_fn(0, metrics) @@ -470,10 +536,12 @@ def training_epoch_with_timing( prefill_key, local_key = jax.random.split(local_key) prefill_keys = jax.random.split(prefill_key, local_devices_to_use) training_state, env_state, buffer_state, _ = prefill_replay_buffer( - training_state, env_state, buffer_state, prefill_keys) + training_state, env_state, buffer_state, prefill_keys + ) - replay_size = jnp.sum(jax.vmap( - replay_buffer.size)(buffer_state)) * jax.process_count() + replay_size = ( + jnp.sum(jax.vmap(replay_buffer.size)(buffer_state)) * jax.process_count() + ) logging.info('replay size after prefill %s', replay_size) assert replay_size >= min_replay_size training_walltime = time.time() - t @@ -485,9 +553,11 @@ def training_epoch_with_timing( # Optimization epoch_key, local_key = jax.random.split(local_key) epoch_keys = jax.random.split(epoch_key, local_devices_to_use) - (training_state, env_state, buffer_state, - training_metrics) = 
training_epoch_with_timing(training_state, env_state, - buffer_state, epoch_keys) + (training_state, env_state, buffer_state, training_metrics) = ( + training_epoch_with_timing( + training_state, env_state, buffer_state, epoch_keys + ) + ) current_step = int(_unpmap(training_state.env_steps)) # Eval and logging @@ -495,15 +565,18 @@ def training_epoch_with_timing( if checkpoint_logdir: # Save current policy. params = _unpmap( - (training_state.normalizer_params, training_state.policy_params)) + (training_state.normalizer_params, training_state.policy_params) + ) path = f'{checkpoint_logdir}_sac_{current_step}.pkl' model.save_params(path, params) # Run evals. metrics = evaluator.run_evaluation( _unpmap( - (training_state.normalizer_params, training_state.policy_params)), - training_metrics) + (training_state.normalizer_params, training_state.policy_params) + ), + training_metrics, + ) logging.info(metrics) progress_fn(current_step, metrics) @@ -511,7 +584,8 @@ def training_epoch_with_timing( assert total_steps >= num_timesteps params = _unpmap( - (training_state.normalizer_params, training_state.policy_params)) + (training_state.normalizer_params, training_state.policy_params) + ) # If there was no mistakes the training_state should still be identical on all # devices. diff --git a/brax/training/agents/sac/train_test.py b/brax/training/agents/sac/train_test.py index a03e7b08..78a48274 100644 --- a/brax/training/agents/sac/train_test.py +++ b/brax/training/agents/sac/train_test.py @@ -44,7 +44,8 @@ def testTrain(self): reward_scaling=10, grad_updates_per_step=64, num_evals=3, - seed=0) + seed=0, + ) self.assertGreater(metrics['eval/episode_reward'], 140 * 0.995) self.assertEqual(fast.reset_count, 3) # type: ignore # once for prefill, once for train, once for eval @@ -58,12 +59,14 @@ def testNetworkEncoding(self, normalize_observations): num_timesteps=128, episode_length=128, num_envs=128, - normalize_observations=normalize_observations) + normalize_observations=normalize_observations, + ) normalize_fn = lambda x, y: x if normalize_observations: normalize_fn = running_statistics.normalize - sac_network = sac_networks.make_sac_networks(env.observation_size, - env.action_size, normalize_fn) + sac_network = sac_networks.make_sac_networks( + env.observation_size, env.action_size, normalize_fn + ) inference = sac_networks.make_inference_fn(sac_network) byte_encoding = pickle.dumps(params) decoded_params = pickle.loads(byte_encoding) @@ -71,7 +74,8 @@ def testNetworkEncoding(self, normalize_observations): # Compute one action. state = env.reset(jax.random.PRNGKey(0)) original_action = original_inference(decoded_params)( - state.obs, jax.random.PRNGKey(0))[0] + state.obs, jax.random.PRNGKey(0) + )[0] action = inference(decoded_params)(state.obs, jax.random.PRNGKey(0))[0] self.assertSequenceEqual(original_action, action) env.step(state, action) diff --git a/brax/training/distribution.py b/brax/training/distribution.py index 8f03d3ae..971a4a9a 100644 --- a/brax/training/distribution.py +++ b/brax/training/distribution.py @@ -85,7 +85,8 @@ def entropy(self, parameters, seed): dist = self.create_dist(parameters) entropy = dist.entropy() entropy += self._postprocessor.forward_log_det_jacobian( - dist.sample(seed=seed)) + dist.sample(seed=seed) + ) if self._event_ndims == 1: entropy = jnp.sum(entropy, axis=-1) return entropy @@ -106,11 +107,11 @@ def mode(self): def log_prob(self, x): log_unnormalized = -0.5 * jnp.square(x / self.scale - self.loc / self.scale) - log_normalization = 0.5 * jnp.log(2. 
* jnp.pi) + jnp.log(self.scale) + log_normalization = 0.5 * jnp.log(2.0 * jnp.pi) + jnp.log(self.scale) return log_unnormalized - log_normalization def entropy(self): - log_normalization = 0.5 * jnp.log(2. * jnp.pi) + jnp.log(self.scale) + log_normalization = 0.5 * jnp.log(2.0 * jnp.pi) + jnp.log(self.scale) entropy = 0.5 + log_normalization return entropy * jnp.ones_like(self.loc) @@ -125,7 +126,7 @@ def inverse(self, y): return jnp.arctanh(y) def forward_log_det_jacobian(self, x): - return 2. * (jnp.log(2.) - x - jax.nn.softplus(-2. * x)) + return 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x)) class NormalTanhDistribution(ParametricDistribution): @@ -150,7 +151,8 @@ def __init__(self, event_size, min_std=0.001, var_scale=1): param_size=2 * event_size, postprocessor=TanhBijector(), event_ndims=1, - reparametrizable=True) + reparametrizable=True, + ) self._min_std = min_std self._var_scale = var_scale diff --git a/brax/training/gradients.py b/brax/training/gradients.py index 3b035616..185419fb 100644 --- a/brax/training/gradients.py +++ b/brax/training/gradients.py @@ -20,9 +20,11 @@ import optax -def loss_and_pgrad(loss_fn: Callable[..., float], - pmap_axis_name: Optional[str], - has_aux: bool = False): +def loss_and_pgrad( + loss_fn: Callable[..., float], + pmap_axis_name: Optional[str], + has_aux: bool = False, +): g = jax.value_and_grad(loss_fn, has_aux=has_aux) def h(*args, **kwargs): @@ -32,10 +34,12 @@ def h(*args, **kwargs): return g if pmap_axis_name is None else h -def gradient_update_fn(loss_fn: Callable[..., float], - optimizer: optax.GradientTransformation, - pmap_axis_name: Optional[str], - has_aux: bool = False): +def gradient_update_fn( + loss_fn: Callable[..., float], + optimizer: optax.GradientTransformation, + pmap_axis_name: Optional[str], + has_aux: bool = False, +): """Wrapper of the loss function that apply gradient updates. Args: @@ -51,7 +55,8 @@ def gradient_update_fn(loss_fn: Callable[..., float], and the new optimizer state. """ loss_and_pgrad_fn = loss_and_pgrad( - loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux) + loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux + ) def f(*args, optimizer_state): value, grads = loss_and_pgrad_fn(*args) diff --git a/brax/training/learner.py b/brax/training/learner.py index 53a788cc..62fa1a28 100644 --- a/brax/training/learner.py +++ b/brax/training/learner.py @@ -319,6 +319,7 @@ def main(unused_argv): # Output an episode trajectory. 
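Aside, not part of the patch: TanhBijector.forward_log_det_jacobian above is the numerically stable form of log|d tanh(x)/dx| = log(1 − tanh(x)^2). A quick check of the identity:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 13)
stable = 2.0 * (jnp.log(2.0) - x - jax.nn.softplus(-2.0 * x))
naive = jnp.log(1.0 - jnp.tanh(x) ** 2)
assert jnp.allclose(stable, naive, atol=1e-4)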
env = get_environment(_ENV.value) + @jax.jit def jit_next_state(state, key): new_key, tmp_key = jax.random.split(key) diff --git a/brax/training/networks.py b/brax/training/networks.py index d439ee1c..1a215656 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -40,6 +40,7 @@ class FeedForwardNetwork: class MLP(linen.Module): """MLP module.""" + layer_sizes: Sequence[int] activation: ActivationFn = linen.relu kernel_init: Initializer = jax.nn.initializers.lecun_uniform() @@ -55,8 +56,8 @@ def __call__(self, data: jnp.ndarray): hidden_size, name=f'hidden_{i}', kernel_init=self.kernel_init, - use_bias=self.bias)( - hidden) + use_bias=self.bias, + )(hidden) if i != len(self.layer_sizes) - 1 or self.activate_final: hidden = self.activation(hidden) if self.layer_norm: @@ -66,6 +67,7 @@ def __call__(self, data: jnp.ndarray): class SNMLP(linen.Module): """MLP module with Spectral Normalization.""" + layer_sizes: Sequence[int] activation: ActivationFn = linen.relu kernel_init: Initializer = jax.nn.initializers.lecun_uniform() @@ -80,8 +82,8 @@ def __call__(self, data: jnp.ndarray): hidden_size, name=f'hidden_{i}', kernel_init=self.kernel_init, - use_bias=self.bias)( - hidden) + use_bias=self.bias, + )(hidden) if i != len(self.layer_sizes) - 1 or self.activate_final: hidden = self.activation(hidden) return hidden @@ -182,19 +184,20 @@ def _get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int: def make_policy_network( param_size: int, obs_size: types.ObservationSize, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), activation: ActivationFn = linen.relu, kernel_init: Initializer = jax.nn.initializers.lecun_uniform(), layer_norm: bool = False, - obs_key: str = 'state') -> FeedForwardNetwork: + obs_key: str = 'state', +) -> FeedForwardNetwork: """Creates a policy network.""" policy_module = MLP( layer_sizes=list(hidden_layer_sizes) + [param_size], activation=activation, kernel_init=kernel_init, - layer_norm=layer_norm) + layer_norm=layer_norm, + ) def apply(processor_params, policy_params, obs): obs = preprocess_observations_fn(obs, processor_params) @@ -204,21 +207,23 @@ def apply(processor_params, policy_params, obs): obs_size = _get_obs_state_size(obs_size, obs_key) dummy_obs = jnp.zeros((1, obs_size)) return FeedForwardNetwork( - init=lambda key: policy_module.init(key, dummy_obs), apply=apply) + init=lambda key: policy_module.init(key, dummy_obs), apply=apply + ) def make_value_network( obs_size: types.ObservationSize, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), activation: ActivationFn = linen.relu, - obs_key: str = 'state') -> FeedForwardNetwork: + obs_key: str = 'state', +) -> FeedForwardNetwork: """Creates a value network.""" value_module = MLP( layer_sizes=list(hidden_layer_sizes) + [1], activation=activation, - kernel_init=jax.nn.initializers.lecun_uniform()) + kernel_init=jax.nn.initializers.lecun_uniform(), + ) def apply(processor_params, value_params, obs): obs = preprocess_observations_fn(obs, processor_params) @@ -228,22 +233,24 @@ def apply(processor_params, value_params, obs): obs_size = _get_obs_state_size(obs_size, obs_key) 
dummy_obs = jnp.zeros((1, obs_size)) return FeedForwardNetwork( - init=lambda key: value_module.init(key, dummy_obs), apply=apply) + init=lambda key: value_module.init(key, dummy_obs), apply=apply + ) def make_q_network( obs_size: types.ObservationSize, action_size: int, - preprocess_observations_fn: types.PreprocessObservationFn = types - .identity_observation_preprocessor, + preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), activation: ActivationFn = linen.relu, n_critics: int = 2, - layer_norm: bool = False) -> FeedForwardNetwork: + layer_norm: bool = False, +) -> FeedForwardNetwork: """Creates a value network.""" class QModule(linen.Module): """Q Module.""" + n_critics: int @linen.compact @@ -255,8 +262,8 @@ def __call__(self, obs: jnp.ndarray, actions: jnp.ndarray): layer_sizes=list(hidden_layer_sizes) + [1], activation=activation, kernel_init=jax.nn.initializers.lecun_uniform(), - layer_norm=layer_norm)( - hidden) + layer_norm=layer_norm, + )(hidden) res.append(q) return jnp.concatenate(res, axis=-1) @@ -269,7 +276,8 @@ def apply(processor_params, q_params, obs, actions): dummy_obs = jnp.zeros((1, obs_size)) dummy_action = jnp.zeros((1, action_size)) return FeedForwardNetwork( - init=lambda key: q_module.init(key, dummy_obs, dummy_action), apply=apply) + init=lambda key: q_module.init(key, dummy_obs, dummy_action), apply=apply + ) def make_model( @@ -290,25 +298,28 @@ def make_model( a model """ warnings.warn( - 'make_model is deprecated, use make_{policy|q|value}_network instead.') + 'make_model is deprecated, use make_{policy|q|value}_network instead.' + ) dummy_obs = jnp.zeros((1, obs_size)) if spectral_norm: module = SNMLP(layer_sizes=layer_sizes, activation=activation) model = FeedForwardNetwork( - init=lambda rng1, rng2: module.init({ - 'params': rng1, - 'sing_vec': rng2 - }, dummy_obs), - apply=module.apply) + init=lambda rng1, rng2: module.init( + {'params': rng1, 'sing_vec': rng2}, dummy_obs + ), + apply=module.apply, + ) else: module = MLP(layer_sizes=layer_sizes, activation=activation) model = FeedForwardNetwork( - init=lambda rng: module.init(rng, dummy_obs), apply=module.apply) + init=lambda rng: module.init(rng, dummy_obs), apply=module.apply + ) return model -def make_models(policy_params_size: int, - obs_size: int) -> Tuple[FeedForwardNetwork, FeedForwardNetwork]: +def make_models( + policy_params_size: int, obs_size: int +) -> Tuple[FeedForwardNetwork, FeedForwardNetwork]: """Creates models for policy and value functions. Args: @@ -319,7 +330,8 @@ def make_models(policy_params_size: int, a model for policy and a model for value function """ warnings.warn( - 'make_models is deprecated, use make_{policy|q|value}_network instead.') + 'make_models is deprecated, use make_{policy|q|value}_network instead.' + ) policy_model = make_model([32, 32, 32, 32, policy_params_size], obs_size) value_model = make_model([256, 256, 256, 256, 256, 1], obs_size) return policy_model, value_model diff --git a/brax/training/pmap.py b/brax/training/pmap.py index 037d4c35..2d6f08d4 100644 --- a/brax/training/pmap.py +++ b/brax/training/pmap.py @@ -53,9 +53,9 @@ def is_replicated(x: Any, axis_name: str) -> jnp.ndarray: boolean whether x is replicated. 
""" fp = _fingerprint(x) - return jax.lax.pmin( - fp, axis_name=axis_name) == jax.lax.pmax( - fp, axis_name=axis_name) + return jax.lax.pmin(fp, axis_name=axis_name) == jax.lax.pmax( + fp, axis_name=axis_name + ) def assert_is_replicated(x: Any, debug: Any = None): diff --git a/brax/training/replay_buffers.py b/brax/training/replay_buffers.py index 2707f1af..fbc6e4bd 100644 --- a/brax/training/replay_buffers.py +++ b/brax/training/replay_buffers.py @@ -234,7 +234,9 @@ def sample_internal( # Note that this may be out of bound, but the operations below would still # work fine as they take this number modulo the buffer size. - idx = (jnp.arange(self._sample_batch_size) + buffer_state.sample_position) % buffer_state.insert_position + idx = ( + jnp.arange(self._sample_batch_size) + buffer_state.sample_position + ) % buffer_state.insert_position flat_batch = jnp.take(buffer_state.data, idx, axis=0, mode='wrap') @@ -404,7 +406,9 @@ def sample(buffer_state: State) -> Tuple[State, Sample]: def size(buffer_state: State) -> int: return jnp.sum(jax.vmap(self._buffer.size)(buffer_state)) # pytype: disable=bad-return-type # jnp-type - partition_spec = jax.sharding.PartitionSpec((axis_names),) + partition_spec = jax.sharding.PartitionSpec( + (axis_names), + ) self._partitioned_init = pjit.pjit(init, out_shardings=partition_spec) self._partitioned_insert = pjit.pjit( insert, diff --git a/brax/training/spectral_norm.py b/brax/training/spectral_norm.py index 7961c513..d87bdfcb 100644 --- a/brax/training/spectral_norm.py +++ b/brax/training/spectral_norm.py @@ -20,6 +20,7 @@ - https://arxiv.org/abs/1802.05957 - https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/spectral_norm.py """ + from typing import Any, Callable, Tuple from brax.training.types import PRNGKey @@ -46,6 +47,7 @@ def _l2_normalize(x, axis=None, eps=1e-12): vectors in a batch. Passing `None` views `t` as a flattened vector when calculating the norm (equivalent to Frobenius norm). eps: Epsilon to avoid dividing by zero. + Returns: An array of the same shape as 'x' L2-normalized along 'axis'. """ @@ -70,6 +72,7 @@ class SNDense(linen.Module): n_steps: How many steps of power iteration to perform to approximate the singular value of the input. """ + features: int use_bias: bool = True dtype: Any = jnp.float32 @@ -90,22 +93,24 @@ def __call__(self, inputs: Array) -> Array: The transformed input. """ inputs = jnp.asarray(inputs, self.dtype) - kernel = self.param('kernel', - self.kernel_init, - (inputs.shape[-1], self.features)) + kernel = self.param( + 'kernel', self.kernel_init, (inputs.shape[-1], self.features) + ) kernel = jnp.asarray(kernel, self.dtype) kernel_shape = kernel.shape # Handle scalars. if kernel.ndim <= 1: - raise ValueError('Spectral normalization is not well defined for ' - 'scalar inputs.') + raise ValueError( + 'Spectral normalization is not well defined for scalar inputs.' + ) # Handle higher-order tensors. elif kernel.ndim > 2: kernel = jnp.reshape(kernel, [-1, kernel.shape[-1]]) key = self.make_rng('sing_vec') - u0_state = self.variable('sing_vec', 'u0', normal(stddev=1.), key, - (1, kernel.shape[-1])) + u0_state = self.variable( + 'sing_vec', 'u0', normal(stddev=1.0), key, (1, kernel.shape[-1]) + ) u0 = u0_state.value # Power iteration for the weight's singular value. 
@@ -123,9 +128,12 @@ def __call__(self, inputs: Array) -> Array: u0_state.value = u0 - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())), - precision=self.precision) + y = lax.dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) if self.use_bias: bias = self.param('bias', self.bias_init, (self.features,)) bias = jnp.asarray(bias, self.dtype) diff --git a/brax/training/types.py b/brax/training/types.py index 61839cf6..19de3a7d 100644 --- a/brax/training/types.py +++ b/brax/training/types.py @@ -42,6 +42,7 @@ class Transition(NamedTuple): """Container for a transition.""" + observation: NestedArray action: NestedArray reward: NestedArray @@ -70,8 +71,9 @@ def __call__( pass -def identity_observation_preprocessor(observation: Observation, - preprocessor_params: PreprocessorParams): +def identity_observation_preprocessor( + observation: Observation, preprocessor_params: PreprocessorParams +): del preprocessor_params return observation @@ -82,7 +84,6 @@ def __call__( self, observation_size: ObservationSize, action_size: int, - preprocess_observations_fn: - PreprocessObservationFn = identity_observation_preprocessor + preprocess_observations_fn: PreprocessObservationFn = identity_observation_preprocessor, ) -> NetworkType: pass
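A minimal usage sketch of the gradient_update_fn wrapper reformatted above (brax/training/gradients.py), assuming an optax optimizer, a toy quadratic loss, and single-device training (pmap_axis_name=None); the loss, batch, and parameter names are illustrative and not taken from the patch.

# Illustrative sketch: names and shapes are assumptions, not Brax code.
import jax.numpy as jnp
import optax

from brax.training import gradients


def loss_fn(params, batch):
  # Toy quadratic loss standing in for a real training loss.
  return jnp.mean((params['w'] * batch['x'] - batch['y']) ** 2)


optimizer = optax.adam(1e-3)
params = {'w': jnp.ones(())}
optimizer_state = optimizer.init(params)
update_fn = gradients.gradient_update_fn(loss_fn, optimizer, pmap_axis_name=None)

batch = {'x': jnp.arange(4.0), 'y': 2.0 * jnp.arange(4.0)}
# Per the docstring: takes the loss function's arguments plus the optimizer
# state, and returns the loss, the new parameters, and the new optimizer state.
loss, params, optimizer_state = update_fn(params, batch, optimizer_state=optimizer_state)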
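Similarly, a hedged sketch of the reformatted make_policy_network / FeedForwardNetwork API from brax/training/networks.py; obs_size, action_size, and the hidden layer sizes are made up for illustration, and None is passed for the processor params because the default identity_observation_preprocessor discards them.

# Illustrative sketch: sizes are assumptions, not Brax defaults.
import jax
import jax.numpy as jnp

from brax.training import networks

obs_size, action_size = 8, 2
# param_size=2 * action_size mirrors NormalTanhDistribution, which uses
# param_size=2 * event_size (a loc and a scale per action dimension).
policy_net = networks.make_policy_network(
    param_size=2 * action_size,
    obs_size=obs_size,
    hidden_layer_sizes=(32, 32),
)
policy_params = policy_net.init(jax.random.PRNGKey(0))
obs = jnp.zeros((1, obs_size))
# apply(processor_params, policy_params, obs); the identity preprocessor
# ignores processor_params, so None suffices here.
logits = policy_net.apply(None, policy_params, obs)  # shape (1, 2 * action_size)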