From 2171b087a7a354ea7488256d979a344f029afcfc Mon Sep 17 00:00:00 2001 From: James Cotton Date: Fri, 18 Nov 2022 13:51:14 +0000 Subject: [PATCH 01/10] SHAC: check in stub based on PPO --- brax/training/agents/shac/__init__.py | 14 + brax/training/agents/shac/losses.py | 182 +++++++++++++ brax/training/agents/shac/networks.py | 88 ++++++ brax/training/agents/shac/train.py | 341 ++++++++++++++++++++++++ brax/training/agents/shac/train_test.py | 79 ++++++ 5 files changed, 704 insertions(+) create mode 100644 brax/training/agents/shac/__init__.py create mode 100644 brax/training/agents/shac/losses.py create mode 100644 brax/training/agents/shac/networks.py create mode 100644 brax/training/agents/shac/train.py create mode 100644 brax/training/agents/shac/train_test.py diff --git a/brax/training/agents/shac/__init__.py b/brax/training/agents/shac/__init__.py new file mode 100644 index 00000000..6d7c8bbb --- /dev/null +++ b/brax/training/agents/shac/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py new file mode 100644 index 00000000..056eae1e --- /dev/null +++ b/brax/training/agents/shac/losses.py @@ -0,0 +1,182 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Proximal policy optimization training. + +See: https://arxiv.org/pdf/1707.06347.pdf +""" + +from typing import Any, Tuple + +from brax.training import types +from brax.training.agents.shac import networks as shac_networks +from brax.training.types import Params +import flax +import jax +import jax.numpy as jnp + + +@flax.struct.dataclass +class SHACNetworkParams: + """Contains training state for the learner.""" + policy: Params + value: Params + + +def compute_gae(truncation: jnp.ndarray, + termination: jnp.ndarray, + rewards: jnp.ndarray, + values: jnp.ndarray, + bootstrap_value: jnp.ndarray, + lambda_: float = 1.0, + discount: float = 0.99): + """Calculates the Generalized Advantage Estimation (GAE). + + Args: + truncation: A float32 tensor of shape [T, B] with truncation signal. + termination: A float32 tensor of shape [T, B] with termination signal. + rewards: A float32 tensor of shape [T, B] containing rewards generated by + following the behaviour policy. + values: A float32 tensor of shape [T, B] with the value function estimates + wrt. the target policy. + bootstrap_value: A float32 of shape [B] with the value function estimate at + time T. 
+ lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). Defaults to + lambda_=1. + discount: TD discount. + + Returns: + A float32 tensor of shape [T, B]. Can be used as target to + train a baseline (V(x_t) - vs_t)^2. + A float32 tensor of shape [T, B] of advantages. + """ + + truncation_mask = 1 - truncation + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = jnp.concatenate( + [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) + deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values + deltas *= truncation_mask + + acc = jnp.zeros_like(bootstrap_value) + vs_minus_v_xs = [] + + def compute_vs_minus_v_xs(carry, target_t): + lambda_, acc = carry + truncation_mask, delta, termination = target_t + acc = delta + discount * (1 - termination) * truncation_mask * lambda_ * acc + return (lambda_, acc), (acc) + + (_, _), (vs_minus_v_xs) = jax.lax.scan( + compute_vs_minus_v_xs, (lambda_, acc), + (truncation_mask, deltas, termination), + length=int(truncation_mask.shape[0]), + reverse=True) + # Add V(x_s) to get v_s. + vs = jnp.add(vs_minus_v_xs, values) + + vs_t_plus_1 = jnp.concatenate( + [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) + advantages = (rewards + discount * + (1 - termination) * vs_t_plus_1 - values) * truncation_mask + return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) + + +def compute_shac_loss( + params: SHACNetworkParams, + normalizer_params: Any, + data: types.Transition, + rng: jnp.ndarray, + shac_network: shac_networks.SHACNetworks, + entropy_cost: float = 1e-4, + discounting: float = 0.9, + reward_scaling: float = 1.0, + gae_lambda: float = 0.95, + clipping_epsilon: float = 0.3, + normalize_advantage: bool = True) -> Tuple[jnp.ndarray, types.Metrics]: + """Computes SHAC loss. + + Args: + params: Network parameters, + normalizer_params: Parameters of the normalizer. + data: Transition that with leading dimension [B, T]. extra fields required + are ['state_extras']['truncation'] ['policy_extras']['raw_action'] + ['policy_extras']['log_prob'] + rng: Random key + shac_network: SHAC networks. + entropy_cost: entropy cost. + discounting: discounting, + reward_scaling: reward multiplier. + gae_lambda: General advantage estimation lambda. + clipping_epsilon: Policy loss clipping epsilon + normalize_advantage: whether to normalize advantage estimate + + Returns: + A tuple (loss, metrics) + """ + parametric_action_distribution = shac_network.parametric_action_distribution + policy_apply = shac_network.policy_network.apply + value_apply = shac_network.value_network.apply + + # Put the time dimension first. 
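+  # (Transitions are collected with leading dimensions [B, T]; the scans in
+  # compute_gae run over the time axis, so the data is transposed to [T, B]
+  # before any loss terms are computed.)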
+ data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) + policy_logits = policy_apply(normalizer_params, params.policy, + data.observation) + + baseline = value_apply(normalizer_params, params.value, data.observation) + + bootstrap_value = value_apply(normalizer_params, params.value, + data.next_observation[-1]) + + rewards = data.reward * reward_scaling + truncation = data.extras['state_extras']['truncation'] + termination = (1 - data.discount) * (1 - truncation) + + target_action_log_probs = parametric_action_distribution.log_prob( + policy_logits, data.extras['policy_extras']['raw_action']) + behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] + + vs, advantages = compute_gae( + truncation=truncation, + termination=termination, + rewards=rewards, + values=baseline, + bootstrap_value=bootstrap_value, + lambda_=gae_lambda, + discount=discounting) + if normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + rho_s = jnp.exp(target_action_log_probs - behaviour_action_log_probs) + + surrogate_loss1 = rho_s * advantages + surrogate_loss2 = jnp.clip(rho_s, 1 - clipping_epsilon, + 1 + clipping_epsilon) * advantages + + policy_loss = -jnp.mean(jnp.minimum(surrogate_loss1, surrogate_loss2)) + + # Value function loss + v_error = vs - baseline + v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 + + # Entropy reward + entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) + entropy_loss = entropy_cost * -entropy + + total_loss = policy_loss + v_loss + entropy_loss + return total_loss, { + 'total_loss': total_loss, + 'policy_loss': policy_loss, + 'v_loss': v_loss, + 'entropy_loss': entropy_loss + } diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py new file mode 100644 index 00000000..9240bc0f --- /dev/null +++ b/brax/training/agents/shac/networks.py @@ -0,0 +1,88 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SHAC networks.""" + +from typing import Sequence, Tuple + +from brax.training import distribution +from brax.training import networks +from brax.training import types +from brax.training.types import PRNGKey +import flax +from flax import linen + + +@flax.struct.dataclass +class SHACNetworks: + policy_network: networks.FeedForwardNetwork + value_network: networks.FeedForwardNetwork + parametric_action_distribution: distribution.ParametricDistribution + + +def make_inference_fn(shac_networks: SHACNetworks): + """Creates params and inference function for the SHAC agent.""" + + def make_policy(params: types.PolicyParams, + deterministic: bool = False) -> types.Policy: + policy_network = shac_networks.policy_network + parametric_action_distribution = shac_networks.parametric_action_distribution + + def policy(observations: types.Observation, + key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]: + logits = policy_network.apply(*params, observations) + if deterministic: + return shac_networks.parametric_action_distribution.mode(logits), {} + raw_actions = parametric_action_distribution.sample_no_postprocessing( + logits, key_sample) + log_prob = parametric_action_distribution.log_prob(logits, raw_actions) + postprocessed_actions = parametric_action_distribution.postprocess( + raw_actions) + return postprocessed_actions, { + 'log_prob': log_prob, + 'raw_action': raw_actions + } + + return policy + + return make_policy + + +def make_shac_networks( + observation_size: int, + action_size: int, + preprocess_observations_fn: types.PreprocessObservationFn = types + .identity_observation_preprocessor, + policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, + value_hidden_layer_sizes: Sequence[int] = (256,) * 5, + activation: networks.ActivationFn = linen.swish) -> SHACNetworks: + """Make SHAC networks with preprocessor.""" + parametric_action_distribution = distribution.NormalTanhDistribution( + event_size=action_size) + policy_network = networks.make_policy_network( + parametric_action_distribution.param_size, + observation_size, + preprocess_observations_fn=preprocess_observations_fn, + hidden_layer_sizes=policy_hidden_layer_sizes, + activation=activation) + value_network = networks.make_value_network( + observation_size, + preprocess_observations_fn=preprocess_observations_fn, + hidden_layer_sizes=value_hidden_layer_sizes, + activation=activation) + + return SHACNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution) diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py new file mode 100644 index 00000000..46df46d6 --- /dev/null +++ b/brax/training/agents/shac/train.py @@ -0,0 +1,341 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Short-Horizon Actor Critic. 
+ +See: https://arxiv.org/pdf/2204.07137.pdf +and https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py +""" + +import functools +import time +from typing import Callable, Optional, Tuple + +from absl import logging +from brax import envs +from brax.envs import wrappers +from brax.training import acting +from brax.training import gradients +from brax.training import pmap +from brax.training import types +from brax.training.acme import running_statistics +from brax.training.acme import specs +from brax.training.agents.shac import losses as shac_losses +from brax.training.agents.shac import networks as shac_networks +from brax.training.types import Params +from brax.training.types import PRNGKey +import flax +import jax +import jax.numpy as jnp +import optax + +InferenceParams = Tuple[running_statistics.NestedMeanStd, Params] +Metrics = types.Metrics + +_PMAP_AXIS_NAME = 'i' + + +@flax.struct.dataclass +class TrainingState: + """Contains training state for the learner.""" + optimizer_state: optax.OptState + params: shac_losses.SHACNetworkParams + normalizer_params: running_statistics.RunningStatisticsState + env_steps: jnp.ndarray + + +def _unpmap(v): + return jax.tree_util.tree_map(lambda x: x[0], v) + + +def train(environment: envs.Env, + num_timesteps: int, + episode_length: int, + action_repeat: int = 1, + num_envs: int = 1, + max_devices_per_host: Optional[int] = None, + num_eval_envs: int = 128, + learning_rate: float = 1e-4, + entropy_cost: float = 1e-4, + discounting: float = 0.9, + seed: int = 0, + unroll_length: int = 10, + batch_size: int = 32, + num_minibatches: int = 16, + num_updates_per_batch: int = 2, + num_evals: int = 1, + normalize_observations: bool = False, + reward_scaling: float = 1., + clipping_epsilon: float = .3, + gae_lambda: float = .95, + deterministic_eval: bool = False, + network_factory: types.NetworkFactory[ + shac_networks.SHACNetworks] = shac_networks.make_shac_networks, + progress_fn: Callable[[int, Metrics], None] = lambda *args: None, + normalize_advantage: bool = True, + eval_env: Optional[envs.Env] = None): + """SHAC training.""" + assert batch_size * num_minibatches % num_envs == 0 + xt = time.time() + + process_count = jax.process_count() + process_id = jax.process_index() + local_device_count = jax.local_device_count() + local_devices_to_use = local_device_count + if max_devices_per_host: + local_devices_to_use = min(local_devices_to_use, max_devices_per_host) + logging.info( + 'Device count: %d, process count: %d (id %d), local device count: %d, ' + 'devices to be used count: %d', jax.device_count(), process_count, + process_id, local_device_count, local_devices_to_use) + device_count = local_devices_to_use * process_count + + # The number of environment steps executed for every training step. + env_step_per_training_step = ( + batch_size * unroll_length * num_minibatches * action_repeat) + num_evals_after_init = max(num_evals - 1, 1) + # The number of training_step calls per training_epoch call. 
+ # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step)) + num_training_steps_per_epoch = -( + -num_timesteps // (num_evals_after_init * env_step_per_training_step)) + + assert num_envs % device_count == 0 + env = environment + + env = wrappers.wrap_for_training( + env, episode_length=episode_length, action_repeat=action_repeat) + + reset_fn = jax.jit(jax.vmap(env.reset)) + + normalize = lambda x, y: x + if normalize_observations: + normalize = running_statistics.normalize + shac_network = network_factory( + env.observation_size, + env.action_size, + preprocess_observations_fn=normalize) + make_policy = shac_networks.make_inference_fn(shac_network) + + optimizer = optax.adam(learning_rate=learning_rate) + + loss_fn = functools.partial( + shac_losses.compute_shac_loss, + shac_network=shac_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling, + gae_lambda=gae_lambda, + clipping_epsilon=clipping_epsilon, + normalize_advantage=normalize_advantage) + + gradient_update_fn = gradients.gradient_update_fn( + loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + + def minibatch_step( + carry, data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState): + optimizer_state, params, key = carry + key, key_loss = jax.random.split(key) + (_, metrics), params, optimizer_state = gradient_update_fn( + params, + normalizer_params, + data, + key_loss, + optimizer_state=optimizer_state) + + return (optimizer_state, params, key), metrics + + def sgd_step(carry, unused_t, data: types.Transition, + normalizer_params: running_statistics.RunningStatisticsState): + optimizer_state, params, key = carry + key, key_perm, key_grad = jax.random.split(key, 3) + + def convert_data(x: jnp.ndarray): + x = jax.random.permutation(key_perm, x) + x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:]) + return x + + shuffled_data = jax.tree_util.tree_map(convert_data, data) + (optimizer_state, params, _), metrics = jax.lax.scan( + functools.partial(minibatch_step, normalizer_params=normalizer_params), + (optimizer_state, params, key_grad), + shuffled_data, + length=num_minibatches) + return (optimizer_state, params, key), metrics + + def training_step( + carry: Tuple[TrainingState, envs.State, PRNGKey], + unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]: + training_state, state, key = carry + key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) + + policy = make_policy( + (training_state.normalizer_params, training_state.params.policy)) + + def f(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jax.random.split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields=('truncation',)) + return (next_state, next_key), data + + (state, _), data = jax.lax.scan( + f, (state, key_generate_unroll), (), + length=batch_size * num_minibatches // num_envs) + # Have leading dimentions (batch_size * num_minibatches, unroll_length) + data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), + data) + assert data.discount.shape[1:] == (unroll_length,) + + # Update normalization params and normalize observations. 
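+    # (running_statistics.update folds the freshly collected observations into
+    # the running mean/std and aggregates the statistics across the pmap axis,
+    # so every device keeps an identical normalizer state.)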
+ normalizer_params = running_statistics.update( + training_state.normalizer_params, + data.observation, + pmap_axis_name=_PMAP_AXIS_NAME) + + (optimizer_state, params, _), metrics = jax.lax.scan( + functools.partial( + sgd_step, data=data, normalizer_params=normalizer_params), + (training_state.optimizer_state, training_state.params, key_sgd), (), + length=num_updates_per_batch) + + new_training_state = TrainingState( + optimizer_state=optimizer_state, + params=params, + normalizer_params=normalizer_params, + env_steps=training_state.env_steps + env_step_per_training_step) + return (new_training_state, state, new_key), metrics + + def training_epoch(training_state: TrainingState, state: envs.State, + key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + (training_state, state, _), loss_metrics = jax.lax.scan( + training_step, (training_state, state, key), (), + length=num_training_steps_per_epoch) + loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics) + return training_state, state, loss_metrics + + training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME) + + # Note that this is NOT a pure jittable method. + def training_epoch_with_timing( + training_state: TrainingState, env_state: envs.State, + key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]: + nonlocal training_walltime + t = time.time() + (training_state, env_state, + metrics) = training_epoch(training_state, env_state, key) + metrics = jax.tree_util.tree_map(jnp.mean, metrics) + jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics) + + epoch_training_time = time.time() - t + training_walltime += epoch_training_time + sps = (num_training_steps_per_epoch * + env_step_per_training_step) / epoch_training_time + metrics = { + 'training/sps': sps, + 'training/walltime': training_walltime, + **{f'training/{name}': value for name, value in metrics.items()} + } + return training_state, env_state, metrics + + key = jax.random.PRNGKey(seed) + global_key, local_key = jax.random.split(key) + del key + local_key = jax.random.fold_in(local_key, process_id) + local_key, key_env, eval_key = jax.random.split(local_key, 3) + # key_networks should be global, so that networks are initialized the same + # way for different processes. 
+ key_policy, key_value = jax.random.split(global_key) + del global_key + + init_params = shac_losses.SHACNetworkParams( + policy=shac_network.policy_network.init(key_policy), + value=shac_network.value_network.init(key_value)) + training_state = TrainingState( + optimizer_state=optimizer.init(init_params), + params=init_params, + normalizer_params=running_statistics.init_state( + specs.Array((env.observation_size,), jnp.float32)), + env_steps=0) + training_state = jax.device_put_replicated( + training_state, + jax.local_devices()[:local_devices_to_use]) + + key_envs = jax.random.split(key_env, num_envs // process_count) + key_envs = jnp.reshape(key_envs, + (local_devices_to_use, -1) + key_envs.shape[1:]) + env_state = reset_fn(key_envs) + + if not eval_env: + eval_env = env + else: + eval_env = wrappers.wrap_for_training( + eval_env, episode_length=episode_length, action_repeat=action_repeat) + + evaluator = acting.Evaluator( + eval_env, + functools.partial(make_policy, deterministic=deterministic_eval), + num_eval_envs=num_eval_envs, + episode_length=episode_length, + action_repeat=action_repeat, + key=eval_key) + + # Run initial eval + if process_id == 0 and num_evals > 1: + metrics = evaluator.run_evaluation( + _unpmap( + (training_state.normalizer_params, training_state.params.policy)), + training_metrics={}) + logging.info(metrics) + progress_fn(0, metrics) + + training_walltime = 0 + current_step = 0 + for it in range(num_evals_after_init): + logging.info('starting iteration %s %s', it, time.time() - xt) + + # optimization + epoch_key, local_key = jax.random.split(local_key) + epoch_keys = jax.random.split(epoch_key, local_devices_to_use) + (training_state, env_state, + training_metrics) = training_epoch_with_timing(training_state, env_state, + epoch_keys) + current_step = int(_unpmap(training_state.env_steps)) + + if process_id == 0: + # Run evals. + metrics = evaluator.run_evaluation( + _unpmap( + (training_state.normalizer_params, training_state.params.policy)), + training_metrics) + logging.info(metrics) + progress_fn(current_step, metrics) + + total_steps = current_step + assert total_steps >= num_timesteps + + # If there was no mistakes the training_state should still be identical on all + # devices. + pmap.assert_is_replicated(training_state) + params = _unpmap( + (training_state.normalizer_params, training_state.params.policy)) + logging.info('total steps: %s', total_steps) + pmap.synchronize_hosts() + return (make_policy, params, metrics) diff --git a/brax/training/agents/shac/train_test.py b/brax/training/agents/shac/train_test.py new file mode 100644 index 00000000..781e2ea0 --- /dev/null +++ b/brax/training/agents/shac/train_test.py @@ -0,0 +1,79 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SHAC tests.""" +import pickle + +from absl.testing import absltest +from absl.testing import parameterized +from brax import envs +from brax.training.acme import running_statistics +from brax.training.agents.shac import networks as shac_networks +from brax.training.agents.shac import train as shac +import jax + + +class SHACTest(parameterized.TestCase): + """Tests for SHAC module.""" + + + def testTrain(self): + """Test SHAC with a simple env.""" + _, _, metrics = shac.train( + envs.get_environment('fast'), + num_timesteps=2**15, + episode_length=128, + num_envs=64, + learning_rate=3e-4, + entropy_cost=1e-2, + discounting=0.95, + unroll_length=5, + batch_size=64, + num_minibatches=8, + num_updates_per_batch=4, + normalize_observations=True, + seed=2, + reward_scaling=10, + normalize_advantage=False) + self.assertGreater(metrics['eval/episode_reward'], 135) + + @parameterized.parameters(True, False) + def testNetworkEncoding(self, normalize_observations): + env = envs.get_environment('fast') + original_inference, params, _ = shac.train( + env, + num_timesteps=128, + episode_length=128, + num_envs=128, + normalize_observations=normalize_observations) + normalize_fn = lambda x, y: x + if normalize_observations: + normalize_fn = running_statistics.normalize + shac_network = shac_networks.make_shac_networks(env.observation_size, + env.action_size, normalize_fn) + inference = shac_networks.make_inference_fn(shac_network) + byte_encoding = pickle.dumps(params) + decoded_params = pickle.loads(byte_encoding) + + # Compute one action. + state = env.reset(jax.random.PRNGKey(0)) + original_action = original_inference(decoded_params)( + state.obs, jax.random.PRNGKey(0))[0] + action = inference(decoded_params)(state.obs, jax.random.PRNGKey(0))[0] + self.assertSequenceEqual(original_action, action) + env.step(state, action) + + +if __name__ == '__main__': + absltest.main() From 63db1f10c1c77173bf17e8ac208e93d6ed20f7c5 Mon Sep 17 00:00:00 2001 From: James Cotton Date: Fri, 18 Nov 2022 23:45:43 +0000 Subject: [PATCH 02/10] Basic SHAC equations Stills needs to have target network --- brax/training/agents/shac/losses.py | 153 +++++++++++++++++++++--- brax/training/agents/shac/train.py | 9 +- brax/training/agents/shac/train_test.py | 11 +- 3 files changed, 144 insertions(+), 29 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index 056eae1e..8e9820a4 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -93,6 +93,117 @@ def compute_vs_minus_v_xs(carry, target_t): return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) +def compute_policy_loss(truncation: jnp.ndarray, + termination: jnp.ndarray, + rewards: jnp.ndarray, + values: jnp.ndarray, + bootstrap_value: jnp.ndarray, + discount: float = 0.99): + """Calculates the short horizon reward. + + This implements Eq. 5 of 2204.07137. It needs to account for any episodes where + the episode terminates and include the terminal values appopriately. + + Adopted from ppo.losses.compute_gae + + Args: + truncation: A float32 tensor of shape [T, B] with truncation signal. + termination: A float32 tensor of shape [T, B] with termination signal. + rewards: A float32 tensor of shape [T, B] containing rewards generated by + following the behaviour policy. + values: A float32 tensor of shape [T, B] with the value function estimates + wrt. the target policy. + bootstrap_value: A float32 of shape [B] with the value function estimate at + time T. + discount: TD discount. 
+ + Returns: + A scalar loss. + """ + + horizon = rewards.shape[0] + truncation_mask = 1 - truncation + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = jnp.concatenate([values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) + + def sum_step(carry, target_t): + gam, acc = carry + reward, truncation_mask, vtp1, termination = target_t + gam = jnp.where(termination, 1.0, gam * discount) + acc = acc + truncation_mask * jnp.where(termination, vtp1, gam * reward) + return (gam, acc), (None) + + acc = bootstrap_value * discount ** horizon + gam = jnp.ones_like(bootstrap_value) + (_, acc), (_) = jax.lax.scan(sum_step, (gam, acc), + (rewards, truncation_mask, values_t_plus_1, termination)) + + loss = -jnp.mean(acc) / horizon + return loss + + +def compute_target_values(truncation: jnp.ndarray, + termination: jnp.ndarray, + rewards: jnp.ndarray, + values: jnp.ndarray, + bootstrap_value: jnp.ndarray, + discount: float = 0.99, + lambda_: float = 0.95, + td_lambda=True): + """Calculates the target values. + + This implements Eq. 7 of 2204.07137 + https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349 + + Args: + truncation: A float32 tensor of shape [T, B] with truncation signal. + termination: A float32 tensor of shape [T, B] with termination signal. + rewards: A float32 tensor of shape [T, B] containing rewards generated by + following the behaviour policy. + values: A float32 tensor of shape [T, B] with the value function estimates + wrt. the target policy. + bootstrap_value: A float32 of shape [B] with the value function estimate at + time T. + discount: TD discount. + + Returns: + A float32 tensor of shape [T, B]. + """ + truncation_mask = 1 - truncation + # Append bootstrapped value to get [v1, ..., v_t+1] + values_t_plus_1 = jnp.concatenate( + [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) + + if td_lambda: + + def compute_v_st(carry, target_t): + Ai, Bi, lam = carry + reward, truncation_mask, vtp1, termination = target_t + # TODO: should figure out how to handle termination + + lam = lam * lambda_ * (1 - termination) + termination + Ai = (1 - termination) * (lam * discount * Ai + discount * vtp1 + (1. - lam) / (1. - lambda_) * reward) + Bi = discount * (vtp1 * termination + Bi * (1.0 - termination)) + reward + vs = (1.0 - lambda_) * Ai + lam * Bi + + return (Ai, Bi, lam), (vs) + + Ai = jnp.ones_like(bootstrap_value) + Bi = jnp.zeros_like(bootstrap_value) + lam = jnp.ones_like(bootstrap_value) + + (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam), + (rewards, truncation_mask, values_t_plus_1, termination), + length=int(truncation_mask.shape[0]), + reverse=True) + + else: + vs = rewards + discount * values_t_plus_1 + + + return jax.lax.stop_gradient(vs) + + def compute_shac_loss( params: SHACNetworkParams, normalizer_params: Any, @@ -102,9 +213,8 @@ def compute_shac_loss( entropy_cost: float = 1e-4, discounting: float = 0.9, reward_scaling: float = 1.0, - gae_lambda: float = 0.95, - clipping_epsilon: float = 0.3, - normalize_advantage: bool = True) -> Tuple[jnp.ndarray, types.Metrics]: + lambda_: float = 0.95, + clipping_epsilon: float = 0.3) -> Tuple[jnp.ndarray, types.Metrics]: """Computes SHAC loss. Args: @@ -118,7 +228,7 @@ def compute_shac_loss( entropy_cost: entropy cost. discounting: discounting, reward_scaling: reward multiplier. - gae_lambda: General advantage estimation lambda. 
+ lambda_: Lambda for TD value updates clipping_epsilon: Policy loss clipping epsilon normalize_advantage: whether to normalize advantage estimate @@ -143,32 +253,39 @@ def compute_shac_loss( truncation = data.extras['state_extras']['truncation'] termination = (1 - data.discount) * (1 - truncation) - target_action_log_probs = parametric_action_distribution.log_prob( - policy_logits, data.extras['policy_extras']['raw_action']) - behaviour_action_log_probs = data.extras['policy_extras']['log_prob'] - - vs, advantages = compute_gae( + # compute policy loss + policy_loss = compute_policy_loss( truncation=truncation, termination=termination, rewards=rewards, values=baseline, bootstrap_value=bootstrap_value, - lambda_=gae_lambda, discount=discounting) - if normalize_advantage: - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) - rho_s = jnp.exp(target_action_log_probs - behaviour_action_log_probs) - surrogate_loss1 = rho_s * advantages - surrogate_loss2 = jnp.clip(rho_s, 1 - clipping_epsilon, - 1 + clipping_epsilon) * advantages + vs = compute_target_values( + truncation=truncation, + termination=termination, + rewards=rewards, + values=baseline, + bootstrap_value=bootstrap_value, + discount=discounting, + lambda_=lambda_) - policy_loss = -jnp.mean(jnp.minimum(surrogate_loss1, surrogate_loss2)) + vs, advantages = compute_gae( + truncation=truncation, + termination=termination, + rewards=rewards, + values=baseline, + bootstrap_value=bootstrap_value, + lambda_=0.95, + discount=discounting) - # Value function loss v_error = vs - baseline v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 + jax.debug.print("SIZE {x}", x=vs.shape) + jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss}", loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss) + # Entropy reward entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) entropy_loss = entropy_cost * -entropy diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index 46df46d6..d73b2f6e 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -78,12 +78,11 @@ def train(environment: envs.Env, normalize_observations: bool = False, reward_scaling: float = 1., clipping_epsilon: float = .3, - gae_lambda: float = .95, + lambda_: float = .95, deterministic_eval: bool = False, network_factory: types.NetworkFactory[ shac_networks.SHACNetworks] = shac_networks.make_shac_networks, progress_fn: Callable[[int, Metrics], None] = lambda *args: None, - normalize_advantage: bool = True, eval_env: Optional[envs.Env] = None): """SHAC training.""" assert batch_size * num_minibatches % num_envs == 0 @@ -135,9 +134,8 @@ def train(environment: envs.Env, entropy_cost=entropy_cost, discounting=discounting, reward_scaling=reward_scaling, - gae_lambda=gae_lambda, - clipping_epsilon=clipping_epsilon, - normalize_advantage=normalize_advantage) + lambda_=lambda_, + clipping_epsilon=clipping_epsilon) gradient_update_fn = gradients.gradient_update_fn( loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) @@ -282,6 +280,7 @@ def training_epoch_with_timing( key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:]) env_state = reset_fn(key_envs) + print(f'env_state: {env_state.qp.pos.shape}') if not eval_env: eval_env = env diff --git a/brax/training/agents/shac/train_test.py b/brax/training/agents/shac/train_test.py index 781e2ea0..5dfc790e 100644 --- a/brax/training/agents/shac/train_test.py +++ 
b/brax/training/agents/shac/train_test.py @@ -32,20 +32,19 @@ def testTrain(self): """Test SHAC with a simple env.""" _, _, metrics = shac.train( envs.get_environment('fast'), - num_timesteps=2**15, + num_timesteps=2**18, episode_length=128, num_envs=64, - learning_rate=3e-4, + learning_rate=3e-5, entropy_cost=1e-2, discounting=0.95, - unroll_length=5, + unroll_length=10, batch_size=64, num_minibatches=8, - num_updates_per_batch=4, + num_updates_per_batch=1, normalize_observations=True, seed=2, - reward_scaling=10, - normalize_advantage=False) + reward_scaling=10) self.assertGreater(metrics['eval/episode_reward'], 135) @parameterized.parameters(True, False) From b7dad1e99e26707dbd206f6134c18a5f76f2fb8e Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sat, 19 Nov 2022 12:23:21 +0000 Subject: [PATCH 03/10] SHAC: start to split actor/critic learning Need to separate the policy learning so we can differentiate through the experience back to the policy network. --- brax/training/agents/shac/losses.py | 113 ++++++------------------ brax/training/agents/shac/networks.py | 11 +-- brax/training/agents/shac/train.py | 45 ++++++---- brax/training/agents/shac/train_test.py | 3 +- 4 files changed, 60 insertions(+), 112 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index 8e9820a4..4e1dd91d 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -34,65 +34,6 @@ class SHACNetworkParams: value: Params -def compute_gae(truncation: jnp.ndarray, - termination: jnp.ndarray, - rewards: jnp.ndarray, - values: jnp.ndarray, - bootstrap_value: jnp.ndarray, - lambda_: float = 1.0, - discount: float = 0.99): - """Calculates the Generalized Advantage Estimation (GAE). - - Args: - truncation: A float32 tensor of shape [T, B] with truncation signal. - termination: A float32 tensor of shape [T, B] with termination signal. - rewards: A float32 tensor of shape [T, B] containing rewards generated by - following the behaviour policy. - values: A float32 tensor of shape [T, B] with the value function estimates - wrt. the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at - time T. - lambda_: Mix between 1-step (lambda_=0) and n-step (lambda_=1). Defaults to - lambda_=1. - discount: TD discount. - - Returns: - A float32 tensor of shape [T, B]. Can be used as target to - train a baseline (V(x_t) - vs_t)^2. - A float32 tensor of shape [T, B] of advantages. - """ - - truncation_mask = 1 - truncation - # Append bootstrapped value to get [v1, ..., v_t+1] - values_t_plus_1 = jnp.concatenate( - [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - deltas = rewards + discount * (1 - termination) * values_t_plus_1 - values - deltas *= truncation_mask - - acc = jnp.zeros_like(bootstrap_value) - vs_minus_v_xs = [] - - def compute_vs_minus_v_xs(carry, target_t): - lambda_, acc = carry - truncation_mask, delta, termination = target_t - acc = delta + discount * (1 - termination) * truncation_mask * lambda_ * acc - return (lambda_, acc), (acc) - - (_, _), (vs_minus_v_xs) = jax.lax.scan( - compute_vs_minus_v_xs, (lambda_, acc), - (truncation_mask, deltas, termination), - length=int(truncation_mask.shape[0]), - reverse=True) - # Add V(x_s) to get v_s. 
- vs = jnp.add(vs_minus_v_xs, values) - - vs_t_plus_1 = jnp.concatenate( - [vs[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - advantages = (rewards + discount * - (1 - termination) * vs_t_plus_1 - values) * truncation_mask - return jax.lax.stop_gradient(vs), jax.lax.stop_gradient(advantages) - - def compute_policy_loss(truncation: jnp.ndarray, termination: jnp.ndarray, rewards: jnp.ndarray, @@ -130,10 +71,10 @@ def sum_step(carry, target_t): gam, acc = carry reward, truncation_mask, vtp1, termination = target_t gam = jnp.where(termination, 1.0, gam * discount) - acc = acc + truncation_mask * jnp.where(termination, vtp1, gam * reward) + acc = acc + truncation_mask * jnp.where(termination, 0, gam * reward) return (gam, acc), (None) - acc = bootstrap_value * discount ** horizon + acc = bootstrap_value * (discount ** horizon) gam = jnp.ones_like(bootstrap_value) (_, acc), (_) = jax.lax.scan(sum_step, (gam, acc), (rewards, truncation_mask, values_t_plus_1, termination)) @@ -149,7 +90,7 @@ def compute_target_values(truncation: jnp.ndarray, bootstrap_value: jnp.ndarray, discount: float = 0.99, lambda_: float = 0.95, - td_lambda=True): + td_lambda=False): """Calculates the target values. This implements Eq. 7 of 2204.07137 @@ -205,7 +146,7 @@ def compute_v_st(carry, target_t): def compute_shac_loss( - params: SHACNetworkParams, + params: Params, normalizer_params: Any, data: types.Transition, rng: jnp.ndarray, @@ -218,7 +159,7 @@ def compute_shac_loss( """Computes SHAC loss. Args: - params: Network parameters, + params: Value network parameters, normalizer_params: Parameters of the normalizer. data: Transition that with leading dimension [B, T]. extra fields required are ['state_extras']['truncation'] ['policy_extras']['raw_action'] @@ -236,18 +177,16 @@ def compute_shac_loss( A tuple (loss, metrics) """ parametric_action_distribution = shac_network.parametric_action_distribution - policy_apply = shac_network.policy_network.apply + #policy_apply = shac_network.policy_network.apply value_apply = shac_network.value_network.apply # Put the time dimension first. 
data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - policy_logits = policy_apply(normalizer_params, params.policy, - data.observation) + #policy_logits = policy_apply(normalizer_params, params.policy, + # data.observation) - baseline = value_apply(normalizer_params, params.value, data.observation) - - bootstrap_value = value_apply(normalizer_params, params.value, - data.next_observation[-1]) + baseline = value_apply(normalizer_params, params, data.observation) + bootstrap_value = value_apply(normalizer_params, params, data.next_observation[-1]) rewards = data.reward * reward_scaling truncation = data.extras['state_extras']['truncation'] @@ -262,6 +201,8 @@ def compute_shac_loss( bootstrap_value=bootstrap_value, discount=discounting) + policy_loss = -jnp.mean(rewards) + vs = compute_target_values( truncation=truncation, termination=termination, @@ -271,26 +212,30 @@ def compute_shac_loss( discount=discounting, lambda_=lambda_) - vs, advantages = compute_gae( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - lambda_=0.95, - discount=discounting) + if False: + from ..ppo.losses import compute_gae + vs, advantages = compute_gae( + truncation=truncation, + termination=termination, + rewards=rewards, + values=baseline, + bootstrap_value=bootstrap_value, + lambda_=0.95, + discount=discounting) v_error = vs - baseline v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 - jax.debug.print("SIZE {x}", x=vs.shape) - jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss}", loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss) + jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss} MEAN_REWARD {x} MEAN BOOTSTRAP {y}", + loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss, x=jnp.mean(rewards), + y=jnp.mean(bootstrap_value)) # Entropy reward - entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) - entropy_loss = entropy_cost * -entropy + #entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) + #entropy_loss = entropy_cost * -entropy + entropy_loss = 0 - total_loss = policy_loss + v_loss + entropy_loss + total_loss = policy_loss #+ v_loss + entropy_loss return total_loss, { 'total_loss': total_loss, 'policy_loss': policy_loss, diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py index 9240bc0f..8d652834 100644 --- a/brax/training/agents/shac/networks.py +++ b/brax/training/agents/shac/networks.py @@ -44,15 +44,8 @@ def policy(observations: types.Observation, logits = policy_network.apply(*params, observations) if deterministic: return shac_networks.parametric_action_distribution.mode(logits), {} - raw_actions = parametric_action_distribution.sample_no_postprocessing( - logits, key_sample) - log_prob = parametric_action_distribution.log_prob(logits, raw_actions) - postprocessed_actions = parametric_action_distribution.postprocess( - raw_actions) - return postprocessed_actions, { - 'log_prob': log_prob, - 'raw_action': raw_actions - } + return shac_networks.parametric_action_distribution.sample( + logits, key_sample), {} return policy diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index d73b2f6e..f29555aa 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -49,8 +49,10 @@ @flax.struct.dataclass class TrainingState: """Contains training state for the learner.""" - optimizer_state: optax.OptState - params: 
shac_losses.SHACNetworkParams + policy_optimizer_state: optax.OptState + policy_params: Params + value_optimizer_state: optax.OptState + value_params: Params normalizer_params: running_statistics.RunningStatisticsState env_steps: jnp.ndarray @@ -66,7 +68,8 @@ def train(environment: envs.Env, num_envs: int = 1, max_devices_per_host: Optional[int] = None, num_eval_envs: int = 128, - learning_rate: float = 1e-4, + actor_learning_rate: float = 1e-3, + critic_learning_rate: float = 1e-4, entropy_cost: float = 1e-4, discounting: float = 0.9, seed: int = 0, @@ -126,7 +129,8 @@ def train(environment: envs.Env, preprocess_observations_fn=normalize) make_policy = shac_networks.make_inference_fn(shac_network) - optimizer = optax.adam(learning_rate=learning_rate) + policy_optimizer = optax.adam(learning_rate=actor_learning_rate) + value_optimizer = optax.adam(learning_rate=critic_learning_rate) loss_fn = functools.partial( shac_losses.compute_shac_loss, @@ -138,7 +142,7 @@ def train(environment: envs.Env, clipping_epsilon=clipping_epsilon) gradient_update_fn = gradients.gradient_update_fn( - loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) def minibatch_step( carry, data: types.Transition, @@ -154,7 +158,7 @@ def minibatch_step( return (optimizer_state, params, key), metrics - def sgd_step(carry, unused_t, data: types.Transition, + def critic_sgd_step(carry, unused_t, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState): optimizer_state, params, key = carry key, key_perm, key_grad = jax.random.split(key, 3) @@ -178,8 +182,10 @@ def training_step( training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) + # TODO: this needs to be wrapped in a differentiable function for the + # policy loss policy = make_policy( - (training_state.normalizer_params, training_state.params.policy)) + (training_state.normalizer_params, training_state.policy_params)) def f(carry, unused_t): current_state, current_key = carry @@ -210,13 +216,15 @@ def f(carry, unused_t): (optimizer_state, params, _), metrics = jax.lax.scan( functools.partial( - sgd_step, data=data, normalizer_params=normalizer_params), - (training_state.optimizer_state, training_state.params, key_sgd), (), + critic_sgd_step, data=data, normalizer_params=normalizer_params), + (training_state.value_optimizer_state, training_state.value_params, key_sgd), (), length=num_updates_per_batch) new_training_state = TrainingState( - optimizer_state=optimizer_state, - params=params, + policy_optimizer_state=training_state.policy_optimizer_state, + policy_params=training_state.policy_params, + value_optimizer_state=optimizer_state, + value_params=params, normalizer_params=normalizer_params, env_steps=training_state.env_steps + env_step_per_training_step) return (new_training_state, state, new_key), metrics @@ -263,12 +271,13 @@ def training_epoch_with_timing( key_policy, key_value = jax.random.split(global_key) del global_key - init_params = shac_losses.SHACNetworkParams( - policy=shac_network.policy_network.init(key_policy), - value=shac_network.value_network.init(key_value)) + policy_init_params = shac_network.policy_network.init(key_policy) + value_init_params = shac_network.value_network.init(key_value) training_state = TrainingState( - optimizer_state=optimizer.init(init_params), - params=init_params, + policy_optimizer_state=policy_optimizer.init(policy_init_params), + policy_params=policy_init_params, 
+ value_optimizer_state=value_optimizer.init(value_init_params), + value_params=value_init_params, normalizer_params=running_statistics.init_state( specs.Array((env.observation_size,), jnp.float32)), env_steps=0) @@ -322,7 +331,7 @@ def training_epoch_with_timing( # Run evals. metrics = evaluator.run_evaluation( _unpmap( - (training_state.normalizer_params, training_state.params.policy)), + (training_state.normalizer_params, training_state.policy_params)), training_metrics) logging.info(metrics) progress_fn(current_step, metrics) @@ -334,7 +343,7 @@ def training_epoch_with_timing( # devices. pmap.assert_is_replicated(training_state) params = _unpmap( - (training_state.normalizer_params, training_state.params.policy)) + (training_state.normalizer_params, training_state.policy_params)) logging.info('total steps: %s', total_steps) pmap.synchronize_hosts() return (make_policy, params, metrics) diff --git a/brax/training/agents/shac/train_test.py b/brax/training/agents/shac/train_test.py index 5dfc790e..75ab3b90 100644 --- a/brax/training/agents/shac/train_test.py +++ b/brax/training/agents/shac/train_test.py @@ -35,7 +35,8 @@ def testTrain(self): num_timesteps=2**18, episode_length=128, num_envs=64, - learning_rate=3e-5, + actor_learning_rate=1e-3, + critic_learning_rate=1e-4, entropy_cost=1e-2, discounting=0.95, unroll_length=10, From 9b5111c5be389a160aa506dc320912e7d9a40b24 Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sat, 19 Nov 2022 13:33:19 +0000 Subject: [PATCH 04/10] SHAC: differentiate rewards w.r.t. actions --- brax/training/agents/shac/losses.py | 109 +++++++++++++++++++--------- brax/training/agents/shac/train.py | 90 ++++++++++++++--------- 2 files changed, 130 insertions(+), 69 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index 4e1dd91d..b2b5983d 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -82,6 +82,71 @@ def sum_step(carry, target_t): loss = -jnp.mean(acc) / horizon return loss +def compute_shac_policy_loss( + policy_params: Params, + value_params: Params, + normalizer_params: Any, + data: types.Transition, + rng: jnp.ndarray, + shac_network: shac_networks.SHACNetworks, + entropy_cost: float = 1e-4, + discounting: float = 0.9, + reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]: + """Computes SHAC critic loss. + + Args: + policy_params: Policy network parameters + value_params: Value network parameters, + normalizer_params: Parameters of the normalizer. + data: Transition that with leading dimension [B, T]. extra fields required + are ['state_extras']['truncation'] ['policy_extras']['raw_action'] + ['policy_extras']['log_prob'] + rng: Random key + shac_network: SHAC networks. + entropy_cost: entropy cost. + discounting: discounting, + reward_scaling: reward multiplier. + + Returns: + A scalar loss + """ + parametric_action_distribution = shac_network.parametric_action_distribution + policy_apply = shac_network.policy_network.apply + value_apply = shac_network.value_network.apply + + # Put the time dimension first. 
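+  # Unlike the PPO surrogate, this loss is differentiated directly through
+  # `data`: the rewards and observations come from the differentiable rollout
+  # in train.py, so gradients flow from the loss back to the policy params.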
+ data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) + policy_logits = policy_apply(normalizer_params, policy_params, + data.observation) + + baseline = value_apply(normalizer_params, value_params, data.observation) + bootstrap_value = value_apply(normalizer_params, value_params, data.next_observation[-1]) + + rewards = data.reward * reward_scaling + truncation = data.extras['state_extras']['truncation'] + termination = (1 - data.discount) * (1 - truncation) + + # compute policy loss + policy_loss = compute_policy_loss( + truncation=truncation, + termination=termination, + rewards=rewards, + values=baseline, + bootstrap_value=bootstrap_value, + discount=discounting) + + #policy_loss = -jnp.mean(rewards) + + # Entropy reward + entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) + entropy_loss = entropy_cost * -entropy + #entropy_loss = 0 + + total_loss = policy_loss + entropy_loss + + return total_loss + + def compute_target_values(truncation: jnp.ndarray, termination: jnp.ndarray, @@ -90,7 +155,7 @@ def compute_target_values(truncation: jnp.ndarray, bootstrap_value: jnp.ndarray, discount: float = 0.99, lambda_: float = 0.95, - td_lambda=False): + td_lambda=True): """Calculates the target values. This implements Eq. 7 of 2204.07137 @@ -145,18 +210,16 @@ def compute_v_st(carry, target_t): return jax.lax.stop_gradient(vs) -def compute_shac_loss( +def compute_shac_critic_loss( params: Params, normalizer_params: Any, data: types.Transition, rng: jnp.ndarray, shac_network: shac_networks.SHACNetworks, - entropy_cost: float = 1e-4, discounting: float = 0.9, reward_scaling: float = 1.0, - lambda_: float = 0.95, - clipping_epsilon: float = 0.3) -> Tuple[jnp.ndarray, types.Metrics]: - """Computes SHAC loss. + lambda_: float = 0.95) -> Tuple[jnp.ndarray, types.Metrics]: + """Computes SHAC critic loss. Args: params: Value network parameters, @@ -176,14 +239,10 @@ def compute_shac_loss( Returns: A tuple (loss, metrics) """ - parametric_action_distribution = shac_network.parametric_action_distribution - #policy_apply = shac_network.policy_network.apply + value_apply = shac_network.value_network.apply - # Put the time dimension first. 
data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - #policy_logits = policy_apply(normalizer_params, params.policy, - # data.observation) baseline = value_apply(normalizer_params, params, data.observation) bootstrap_value = value_apply(normalizer_params, params, data.next_observation[-1]) @@ -192,17 +251,6 @@ def compute_shac_loss( truncation = data.extras['state_extras']['truncation'] termination = (1 - data.discount) * (1 - truncation) - # compute policy loss - policy_loss = compute_policy_loss( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - discount=discounting) - - policy_loss = -jnp.mean(rewards) - vs = compute_target_values( truncation=truncation, termination=termination, @@ -226,19 +274,14 @@ def compute_shac_loss( v_error = vs - baseline v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 - jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss} MEAN_REWARD {x} MEAN BOOTSTRAP {y}", - loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss, x=jnp.mean(rewards), - y=jnp.mean(bootstrap_value)) - - # Entropy reward - #entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) - #entropy_loss = entropy_cost * -entropy - entropy_loss = 0 + #jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss} MEAN_REWARD {x} MEAN BOOTSTRAP {y}", + # loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss, x=jnp.mean(rewards), + # y=jnp.mean(bootstrap_value)) - total_loss = policy_loss #+ v_loss + entropy_loss + total_loss = v_loss return total_loss, { 'total_loss': total_loss, - 'policy_loss': policy_loss, + 'policy_loss': 0, 'v_loss': v_loss, - 'entropy_loss': entropy_loss + 'entropy_loss': 0 } diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index f29555aa..533b61fb 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -80,7 +80,6 @@ def train(environment: envs.Env, num_evals: int = 1, normalize_observations: bool = False, reward_scaling: float = 1., - clipping_epsilon: float = .3, lambda_: float = .95, deterministic_eval: bool = False, network_factory: types.NetworkFactory[ @@ -132,24 +131,64 @@ def train(environment: envs.Env, policy_optimizer = optax.adam(learning_rate=actor_learning_rate) value_optimizer = optax.adam(learning_rate=critic_learning_rate) - loss_fn = functools.partial( - shac_losses.compute_shac_loss, + value_loss_fn = functools.partial( + shac_losses.compute_shac_critic_loss, shac_network=shac_network, - entropy_cost=entropy_cost, discounting=discounting, reward_scaling=reward_scaling, - lambda_=lambda_, - clipping_epsilon=clipping_epsilon) + lambda_=lambda_) + + value_gradient_update_fn = gradients.gradient_update_fn( + value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + + policy_loss_fn = functools.partial( + shac_losses.compute_shac_policy_loss, + shac_network=shac_network, + entropy_cost=entropy_cost, + discounting=discounting, + reward_scaling=reward_scaling) + + def rollout_loss_fn(policy_params, value_params, normalizer_params, state, key): + policy = make_policy((normalizer_params, policy_params)) + + key, key_loss = jax.random.split(key) + + def f(carry, unused_t): + current_state, current_key = carry + current_key, next_key = jax.random.split(current_key) + next_state, data = acting.generate_unroll( + env, + current_state, + policy, + current_key, + unroll_length, + extra_fields=('truncation',)) + return (next_state, next_key), data + + 
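+    # Because the unroll below happens inside the loss function, taking the
+    # gradient of rollout_loss_fn w.r.t. policy_params backpropagates through
+    # the policy, the environment dynamics and the rewards it generates.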
(state, _), data = jax.lax.scan( + f, (state, key), (), + length=batch_size * num_minibatches // num_envs) + + # Have leading dimentions (batch_size * num_minibatches, unroll_length) + data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) + data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), + data) + assert data.discount.shape[1:] == (unroll_length,) - gradient_update_fn = gradients.gradient_update_fn( - loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + loss = policy_loss_fn(policy_params, value_params, + normalizer_params, data, key_loss) + + return loss, (state, data) + + policy_gradient_update_fn = gradients.gradient_update_fn( + rollout_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) def minibatch_step( carry, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState): optimizer_state, params, key = carry key, key_loss = jax.random.split(key) - (_, metrics), params, optimizer_state = gradient_update_fn( + (_, metrics), params, optimizer_state = value_gradient_update_fn( params, normalizer_params, data, @@ -182,31 +221,10 @@ def training_step( training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) - # TODO: this needs to be wrapped in a differentiable function for the - # policy loss - policy = make_policy( - (training_state.normalizer_params, training_state.policy_params)) - - def f(carry, unused_t): - current_state, current_key = carry - current_key, next_key = jax.random.split(current_key) - next_state, data = acting.generate_unroll( - env, - current_state, - policy, - current_key, - unroll_length, - extra_fields=('truncation',)) - return (next_state, next_key), data - - (state, _), data = jax.lax.scan( - f, (state, key_generate_unroll), (), - length=batch_size * num_minibatches // num_envs) - # Have leading dimentions (batch_size * num_minibatches, unroll_length) - data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data) - data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), - data) - assert data.discount.shape[1:] == (unroll_length,) + (_, (state, data)), policy_params, policy_optimizer_state = policy_gradient_update_fn( + training_state.policy_params, training_state.value_params, + training_state.normalizer_params, state, key_generate_unroll, + optimizer_state=training_state.policy_optimizer_state) # Update normalization params and normalize observations. normalizer_params = running_statistics.update( @@ -221,8 +239,8 @@ def f(carry, unused_t): length=num_updates_per_batch) new_training_state = TrainingState( - policy_optimizer_state=training_state.policy_optimizer_state, - policy_params=training_state.policy_params, + policy_optimizer_state=policy_optimizer_state, + policy_params=policy_params, value_optimizer_state=optimizer_state, value_params=params, normalizer_params=normalizer_params, From 727d5bc63c1a348ef882c7704059151156220e27 Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sat, 19 Nov 2022 16:08:14 +0000 Subject: [PATCH 05/10] Envs: fast differentiable environment The original fast is not, due to action>0 as the control signal. Thus APG and other differentiable solvers like SHAC perform poorly. 
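
As a rough illustration (not part of this patch; a standalone check with
made-up names), the gradient of the original-style velocity update is zero
almost everywhere, while the `action * (action > 0)` term used here keeps a
useful gradient for positive actions:

    import jax
    import jax.numpy as jnp

    dt = 0.02
    old_step = lambda a: jnp.sum((a > 0) * dt)      # roughly the original fast env
    new_step = lambda a: jnp.sum(a * (a > 0) * dt)  # fast_differentiable update

    print(jax.grad(old_step)(jnp.ones(1)))  # [0.]
    print(jax.grad(new_step)(jnp.ones(1)))  # [0.02]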
--- brax/envs/__init__.py | 2 ++ brax/envs/fast_differentiable.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) create mode 100644 brax/envs/fast_differentiable.py diff --git a/brax/envs/__init__.py b/brax/envs/__init__.py index f3cd0b5c..baa53e4c 100644 --- a/brax/envs/__init__.py +++ b/brax/envs/__init__.py @@ -22,6 +22,7 @@ from brax.envs import acrobot from brax.envs import ant from brax.envs import fast +from brax.envs import fast_differentiable from brax.envs import fetch from brax.envs import grasp from brax.envs import half_cheetah @@ -45,6 +46,7 @@ 'acrobot': acrobot.Acrobot, 'ant': functools.partial(ant.Ant, use_contact_forces=True), 'fast': fast.Fast, + 'fast_differentiable': fast_differentiable.FastDifferentiable, 'fetch': fetch.Fetch, 'grasp': grasp.Grasp, 'halfcheetah': half_cheetah.Halfcheetah, diff --git a/brax/envs/fast_differentiable.py b/brax/envs/fast_differentiable.py new file mode 100644 index 00000000..8d6969f2 --- /dev/null +++ b/brax/envs/fast_differentiable.py @@ -0,0 +1,53 @@ +# Copyright 2022 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Gotta go fast! This trivial Env is meant for unit testing.""" + +import brax +from brax.envs import env +import jax.numpy as jnp + + +class FastDifferentiable(env.Env): + """Trains an agent to go fast.""" + + def __init__(self, **kwargs): + super().__init__(config='dt: .02', **kwargs) + + def reset(self, rng: jnp.ndarray) -> env.State: + zero = jnp.zeros(1) + qp = brax.QP(pos=zero, vel=zero, rot=zero, ang=zero) + obs = jnp.zeros(2) + reward, done = jnp.zeros(2) + return env.State(qp, obs, reward, done) + + def step(self, state: env.State, action: jnp.ndarray) -> env.State: + vel = state.qp.vel + action * (action > 0) * self.sys.config.dt + pos = state.qp.pos + vel * self.sys.config.dt + + qp = state.qp.replace(pos=pos, vel=vel) + obs = jnp.array([pos[0], vel[0]]) + reward = pos[0] + #reward = 1.0 + #reward = action[0] + + return state.replace(qp=qp, obs=obs, reward=reward) + + @property + def observation_size(self): + return 2 + + @property + def action_size(self): + return 1 From 5d35010674b4faa85a36a3e5444feed31d65257c Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sat, 19 Nov 2022 16:09:06 +0000 Subject: [PATCH 06/10] APG: test on fast_differentiable This gives a performance score comparable to other algos. --- brax/training/agents/apg/train_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/brax/training/agents/apg/train_test.py b/brax/training/agents/apg/train_test.py index 94d88c16..28992c03 100644 --- a/brax/training/agents/apg/train_test.py +++ b/brax/training/agents/apg/train_test.py @@ -30,15 +30,14 @@ class APGTest(parameterized.TestCase): def testTrain(self): """Test APG with a simple env.""" _, _, metrics = apg.train( - envs.get_environment('fast'), + envs.get_environment('fast_differentiable'), episode_length=128, num_envs=64, num_evals=200, learning_rate=3e-3, normalize_observations=True, ) - # TODO: Can you make this 135? 
- self.assertGreater(metrics['eval/episode_reward'], 50) + self.assertGreater(metrics['eval/episode_reward'], 135) @parameterized.parameters(True, False) def testNetworkEncoding(self, normalize_observations): From 44830aa2046d7783583086b06e0043d4830b9315 Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sat, 19 Nov 2022 16:13:31 +0000 Subject: [PATCH 07/10] SHAC: passing unit tests now --- brax/training/agents/shac/losses.py | 116 ++++++++---------------- brax/training/agents/shac/networks.py | 12 ++- brax/training/agents/shac/train.py | 18 ++-- brax/training/agents/shac/train_test.py | 8 +- 4 files changed, 62 insertions(+), 92 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index b2b5983d..fe16c3e3 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -34,54 +34,6 @@ class SHACNetworkParams: value: Params -def compute_policy_loss(truncation: jnp.ndarray, - termination: jnp.ndarray, - rewards: jnp.ndarray, - values: jnp.ndarray, - bootstrap_value: jnp.ndarray, - discount: float = 0.99): - """Calculates the short horizon reward. - - This implements Eq. 5 of 2204.07137. It needs to account for any episodes where - the episode terminates and include the terminal values appopriately. - - Adopted from ppo.losses.compute_gae - - Args: - truncation: A float32 tensor of shape [T, B] with truncation signal. - termination: A float32 tensor of shape [T, B] with termination signal. - rewards: A float32 tensor of shape [T, B] containing rewards generated by - following the behaviour policy. - values: A float32 tensor of shape [T, B] with the value function estimates - wrt. the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at - time T. - discount: TD discount. - - Returns: - A scalar loss. - """ - - horizon = rewards.shape[0] - truncation_mask = 1 - truncation - # Append bootstrapped value to get [v1, ..., v_t+1] - values_t_plus_1 = jnp.concatenate([values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - - def sum_step(carry, target_t): - gam, acc = carry - reward, truncation_mask, vtp1, termination = target_t - gam = jnp.where(termination, 1.0, gam * discount) - acc = acc + truncation_mask * jnp.where(termination, 0, gam * reward) - return (gam, acc), (None) - - acc = bootstrap_value * (discount ** horizon) - gam = jnp.ones_like(bootstrap_value) - (_, acc), (_) = jax.lax.scan(sum_step, (gam, acc), - (rewards, truncation_mask, values_t_plus_1, termination)) - - loss = -jnp.mean(acc) / horizon - return loss - def compute_shac_policy_loss( policy_params: Params, value_params: Params, @@ -94,6 +46,9 @@ def compute_shac_policy_loss( reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]: """Computes SHAC critic loss. + This implements Eq. 5 of 2204.07137. It needs to account for any episodes where + the episode terminates and include the terminal values appopriately. + Args: policy_params: Policy network parameters value_params: Value network parameters, @@ -110,41 +65,61 @@ def compute_shac_policy_loss( Returns: A scalar loss """ + parametric_action_distribution = shac_network.parametric_action_distribution policy_apply = shac_network.policy_network.apply value_apply = shac_network.value_network.apply # Put the time dimension first. 
data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - policy_logits = policy_apply(normalizer_params, policy_params, - data.observation) - baseline = value_apply(normalizer_params, value_params, data.observation) - bootstrap_value = value_apply(normalizer_params, value_params, data.next_observation[-1]) + # this is a redundant computation with the critic loss function + # but there isn't a straighforward way to get these values when + # they are used in that step + values = value_apply(normalizer_params, value_params, data.observation) + terminal_values = value_apply(normalizer_params, value_params, data.next_observation[-1]) rewards = data.reward * reward_scaling truncation = data.extras['state_extras']['truncation'] termination = (1 - data.discount) * (1 - truncation) - # compute policy loss - policy_loss = compute_policy_loss( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - discount=discounting) + horizon = rewards.shape[0] + + def sum_step(carry, target_t): + gam, acc = carry + reward, v, truncation, termination = target_t + acc = acc + jnp.where(truncation + termination, gam * v, gam * reward) + gam = jnp.where(termination, 1.0, gam * discounting) + return (gam, acc), (acc) + + acc = terminal_values * (discounting ** horizon) * (1-termination[-1]) * (1-truncation[-1]) + jax.debug.print('acc shape: {x}', x=acc.shape) + gam = jnp.ones_like(terminal_values) + (_, acc), (temp) = jax.lax.scan(sum_step, (gam, acc), + (rewards, values, truncation, termination)) + + policy_loss = -jnp.mean(acc) / horizon + + # inspect the data for one of the rollouts + jax.debug.print('obs={o}, obs_next={n}, values={v}, reward={r}, truncation={t}, terminal={s}', + v=values[:, 0], o=data.observation[:,0], r=data.reward[:,0], + t=truncation[:, 0], s=termination[:,0], n=data.next_observation[:, 0]) - #policy_loss = -jnp.mean(rewards) + jax.debug.print('loss={l}, r={r}', l=policy_loss, r=temp[:,0]) # Entropy reward + policy_logits = policy_apply(normalizer_params, policy_params, + data.observation) entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng)) entropy_loss = entropy_cost * -entropy - #entropy_loss = 0 total_loss = policy_loss + entropy_loss - return total_loss + return total_loss, { + 'total_loss': total_loss, + 'policy_loss': policy_loss, + 'entropy_loss': entropy_loss + } @@ -206,7 +181,6 @@ def compute_v_st(carry, target_t): else: vs = rewards + discount * values_t_plus_1 - return jax.lax.stop_gradient(vs) @@ -260,23 +234,9 @@ def compute_shac_critic_loss( discount=discounting, lambda_=lambda_) - if False: - from ..ppo.losses import compute_gae - vs, advantages = compute_gae( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - lambda_=0.95, - discount=discounting) - v_error = vs - baseline v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 - #jax.debug.print("LOSS {loss} MEAN TARGET {targets} V_LOSS {v_loss} MEAN_REWARD {x} MEAN BOOTSTRAP {y}", - # loss=policy_loss, targets=jnp.mean(vs), v_loss=v_loss, x=jnp.mean(rewards), - # y=jnp.mean(bootstrap_value)) total_loss = v_loss return total_loss, { diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py index 8d652834..c4e325af 100644 --- a/brax/training/agents/shac/networks.py +++ b/brax/training/agents/shac/networks.py @@ -44,8 +44,16 @@ def policy(observations: types.Observation, logits = policy_network.apply(*params, 
observations) if deterministic: return shac_networks.parametric_action_distribution.mode(logits), {} - return shac_networks.parametric_action_distribution.sample( - logits, key_sample), {} + raw_actions = parametric_action_distribution.sample_no_postprocessing( + logits, key_sample) + log_prob = parametric_action_distribution.log_prob(logits, raw_actions) + postprocessed_actions = parametric_action_distribution.postprocess( + raw_actions) + return postprocessed_actions, { + 'log_prob': log_prob, + 'raw_action': raw_actions + } + return policy diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index 533b61fb..f8f19f68 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -175,13 +175,13 @@ def f(carry, unused_t): data) assert data.discount.shape[1:] == (unroll_length,) - loss = policy_loss_fn(policy_params, value_params, + loss, metrics = policy_loss_fn(policy_params, value_params, normalizer_params, data, key_loss) - return loss, (state, data) + return loss, (state, data, metrics) policy_gradient_update_fn = gradients.gradient_update_fn( - rollout_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) def minibatch_step( carry, data: types.Transition, @@ -221,7 +221,7 @@ def training_step( training_state, state, key = carry key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) - (_, (state, data)), policy_params, policy_optimizer_state = policy_gradient_update_fn( + (policy_loss, (state, data, policy_metrics)), policy_params, policy_optimizer_state = policy_gradient_update_fn( training_state.policy_params, training_state.value_params, training_state.normalizer_params, state, key_generate_unroll, optimizer_state=training_state.policy_optimizer_state) @@ -232,18 +232,20 @@ def training_step( data.observation, pmap_axis_name=_PMAP_AXIS_NAME) - (optimizer_state, params, _), metrics = jax.lax.scan( + (value_optimizer_state, value_params, _), metrics = jax.lax.scan( functools.partial( critic_sgd_step, data=data, normalizer_params=normalizer_params), (training_state.value_optimizer_state, training_state.value_params, key_sgd), (), length=num_updates_per_batch) + metrics.update(policy_metrics) + new_training_state = TrainingState( policy_optimizer_state=policy_optimizer_state, policy_params=policy_params, - value_optimizer_state=optimizer_state, - value_params=params, - normalizer_params=normalizer_params, + value_optimizer_state=value_optimizer_state, + value_params=value_params, + normalizer_params=training_state.normalizer_params, env_steps=training_state.env_steps + env_step_per_training_step) return (new_training_state, state, new_key), metrics diff --git a/brax/training/agents/shac/train_test.py b/brax/training/agents/shac/train_test.py index 75ab3b90..b3d5c79a 100644 --- a/brax/training/agents/shac/train_test.py +++ b/brax/training/agents/shac/train_test.py @@ -31,12 +31,12 @@ class SHACTest(parameterized.TestCase): def testTrain(self): """Test SHAC with a simple env.""" _, _, metrics = shac.train( - envs.get_environment('fast'), - num_timesteps=2**18, + envs.get_environment('fast_differentiable'), + num_timesteps=2**15, episode_length=128, num_envs=64, - actor_learning_rate=1e-3, - critic_learning_rate=1e-4, + actor_learning_rate=1.5e-2, + critic_learning_rate=1e-3, entropy_cost=1e-2, discounting=0.95, unroll_length=10, From 4000c95544fc2a4ce219d5ef4363d52b5a31ddbe Mon Sep 17 00:00:00 2001 From: James Cotton 
Date: Sat, 19 Nov 2022 18:55:53 +0000 Subject: [PATCH 08/10] SHAC: add target network --- brax/training/agents/shac/losses.py | 152 +++++++++++----------------- brax/training/agents/shac/train.py | 12 ++- 2 files changed, 69 insertions(+), 95 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index fe16c3e3..68ad34a3 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -83,29 +83,32 @@ def compute_shac_policy_loss( truncation = data.extras['state_extras']['truncation'] termination = (1 - data.discount) * (1 - truncation) - horizon = rewards.shape[0] + # Append terminal values to get [v1, ..., v_t+1] + values_t_plus_1 = jnp.concatenate( + [values[1:], jnp.expand_dims(terminal_values, 0)], axis=0) + # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227 def sum_step(carry, target_t): - gam, acc = carry - reward, v, truncation, termination = target_t - acc = acc + jnp.where(truncation + termination, gam * v, gam * reward) + gam, rew_acc = carry + reward, v, termination = target_t + + # clean up gamma and rew_acc for done envs, otherwise update + rew_acc = jnp.where(termination, 0, rew_acc + gam * reward) gam = jnp.where(termination, 1.0, gam * discounting) - return (gam, acc), (acc) - acc = terminal_values * (discounting ** horizon) * (1-termination[-1]) * (1-truncation[-1]) - jax.debug.print('acc shape: {x}', x=acc.shape) - gam = jnp.ones_like(terminal_values) - (_, acc), (temp) = jax.lax.scan(sum_step, (gam, acc), - (rewards, values, truncation, termination)) + return (gam, rew_acc), (gam, rew_acc) - policy_loss = -jnp.mean(acc) / horizon + rew_acc = jnp.zeros_like(terminal_values) + gam = jnp.ones_like(terminal_values) + (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc), + (rewards, values, termination)) - # inspect the data for one of the rollouts - jax.debug.print('obs={o}, obs_next={n}, values={v}, reward={r}, truncation={t}, terminal={s}', - v=values[:, 0], o=data.observation[:,0], r=data.reward[:,0], - t=truncation[:, 0], s=termination[:,0], n=data.next_observation[:, 0]) + policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values) + # for trials that are truncated (i.e. hit the episode length) include reward for + # terminal state. otherwise, the trial was aborted and should receive zero additional + policy_loss = policy_loss + jnp.sum((-rew_acc - gam_acc * jnp.where(truncation, values_t_plus_1, 0)) * termination) + policy_loss = policy_loss / values.shape[0] / values.shape[1] - jax.debug.print('loss={l}, r={r}', l=policy_loss, r=temp[:,0]) # Entropy reward policy_logits = policy_apply(normalizer_params, policy_params, @@ -122,68 +125,6 @@ def sum_step(carry, target_t): } - -def compute_target_values(truncation: jnp.ndarray, - termination: jnp.ndarray, - rewards: jnp.ndarray, - values: jnp.ndarray, - bootstrap_value: jnp.ndarray, - discount: float = 0.99, - lambda_: float = 0.95, - td_lambda=True): - """Calculates the target values. - - This implements Eq. 7 of 2204.07137 - https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349 - - Args: - truncation: A float32 tensor of shape [T, B] with truncation signal. - termination: A float32 tensor of shape [T, B] with termination signal. - rewards: A float32 tensor of shape [T, B] containing rewards generated by - following the behaviour policy. - values: A float32 tensor of shape [T, B] with the value function estimates - wrt. 
the target policy. - bootstrap_value: A float32 of shape [B] with the value function estimate at - time T. - discount: TD discount. - - Returns: - A float32 tensor of shape [T, B]. - """ - truncation_mask = 1 - truncation - # Append bootstrapped value to get [v1, ..., v_t+1] - values_t_plus_1 = jnp.concatenate( - [values[1:], jnp.expand_dims(bootstrap_value, 0)], axis=0) - - if td_lambda: - - def compute_v_st(carry, target_t): - Ai, Bi, lam = carry - reward, truncation_mask, vtp1, termination = target_t - # TODO: should figure out how to handle termination - - lam = lam * lambda_ * (1 - termination) + termination - Ai = (1 - termination) * (lam * discount * Ai + discount * vtp1 + (1. - lam) / (1. - lambda_) * reward) - Bi = discount * (vtp1 * termination + Bi * (1.0 - termination)) + reward - vs = (1.0 - lambda_) * Ai + lam * Bi - - return (Ai, Bi, lam), (vs) - - Ai = jnp.ones_like(bootstrap_value) - Bi = jnp.zeros_like(bootstrap_value) - lam = jnp.ones_like(bootstrap_value) - - (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam), - (rewards, truncation_mask, values_t_plus_1, termination), - length=int(truncation_mask.shape[0]), - reverse=True) - - else: - vs = rewards + discount * values_t_plus_1 - - return jax.lax.stop_gradient(vs) - - def compute_shac_critic_loss( params: Params, normalizer_params: Any, @@ -192,9 +133,13 @@ def compute_shac_critic_loss( shac_network: shac_networks.SHACNetworks, discounting: float = 0.9, reward_scaling: float = 1.0, - lambda_: float = 0.95) -> Tuple[jnp.ndarray, types.Metrics]: + lambda_: float = 0.95, + td_lambda: bool = True) -> Tuple[jnp.ndarray, types.Metrics]: """Computes SHAC critic loss. + This implements Eq. 7 of 2204.07137 + https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349 + Args: params: Value network parameters, normalizer_params: Parameters of the normalizer. @@ -207,8 +152,7 @@ def compute_shac_critic_loss( discounting: discounting, reward_scaling: reward multiplier. lambda_: Lambda for TD value updates - clipping_epsilon: Policy loss clipping epsilon - normalize_advantage: whether to normalize advantage estimate + td_lambda: whether to use a TD-Lambda value target Returns: A tuple (loss, metrics) @@ -218,25 +162,47 @@ def compute_shac_critic_loss( data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data) - baseline = value_apply(normalizer_params, params, data.observation) - bootstrap_value = value_apply(normalizer_params, params, data.next_observation[-1]) + values = value_apply(normalizer_params, params, data.observation) + terminal_value = value_apply(normalizer_params, params, data.next_observation[-1]) rewards = data.reward * reward_scaling truncation = data.extras['state_extras']['truncation'] termination = (1 - data.discount) * (1 - truncation) - vs = compute_target_values( - truncation=truncation, - termination=termination, - rewards=rewards, - values=baseline, - bootstrap_value=bootstrap_value, - discount=discounting, - lambda_=lambda_) + # Append terminal values to get [v1, ..., v_t+1] + values_t_plus_1 = jnp.concatenate( + [values[1:], jnp.expand_dims(terminal_value, 0)], axis=0) + + # compute target values + if td_lambda: + + def compute_v_st(carry, target_t): + Ai, Bi, lam = carry + reward, vtp1, termination = target_t + + reward = reward * termination + + lam = lam * lambda_ * (1 - termination) + termination + Ai = (1 - termination) * (lam * discounting * Ai + discounting * vtp1 + (1. - lam) / (1. 
- lambda_) * reward) + Bi = discounting * (vtp1 * termination + Bi * (1.0 - termination)) + reward + vs = (1.0 - lambda_) * Ai + lam * Bi + + return (Ai, Bi, lam), (vs) + + Ai = jnp.ones_like(terminal_value) + Bi = jnp.zeros_like(terminal_value) + lam = jnp.ones_like(terminal_value) + (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam), + (rewards, values_t_plus_1, termination), + length=int(termination.shape[0]), + reverse=True) + + else: + vs = rewards + discounting * values_t_plus_1 - v_error = vs - baseline - v_loss = jnp.mean(v_error * v_error) * 0.5 * 0.5 + target_values = jax.lax.stop_gradient(vs) + v_loss = jnp.mean((target_values - values) ** 2) total_loss = v_loss return total_loss, { diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index f8f19f68..9510afce 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -53,6 +53,7 @@ class TrainingState: policy_params: Params value_optimizer_state: optax.OptState value_params: Params + target_value_params: Params normalizer_params: running_statistics.RunningStatisticsState env_steps: jnp.ndarray @@ -80,6 +81,7 @@ def train(environment: envs.Env, num_evals: int = 1, normalize_observations: bool = False, reward_scaling: float = 1., + tau: float = 0.005, # this is 1-alpha from the original paper lambda_: float = .95, deterministic_eval: bool = False, network_factory: types.NetworkFactory[ @@ -222,7 +224,7 @@ def training_step( key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3) (policy_loss, (state, data, policy_metrics)), policy_params, policy_optimizer_state = policy_gradient_update_fn( - training_state.policy_params, training_state.value_params, + training_state.policy_params, training_state.target_value_params, training_state.normalizer_params, state, key_generate_unroll, optimizer_state=training_state.policy_optimizer_state) @@ -238,6 +240,10 @@ def training_step( (training_state.value_optimizer_state, training_state.value_params, key_sgd), (), length=num_updates_per_batch) + target_value_params = jax.tree_util.tree_map( + lambda x, y: x * (1 - tau) + y * tau, training_state.target_value_params, + value_params) + metrics.update(policy_metrics) new_training_state = TrainingState( @@ -245,6 +251,7 @@ def training_step( policy_params=policy_params, value_optimizer_state=value_optimizer_state, value_params=value_params, + target_value_params=target_value_params, normalizer_params=training_state.normalizer_params, env_steps=training_state.env_steps + env_step_per_training_step) return (new_training_state, state, new_key), metrics @@ -298,6 +305,7 @@ def training_epoch_with_timing( policy_params=policy_init_params, value_optimizer_state=value_optimizer.init(value_init_params), value_params=value_init_params, + target_value_params=value_init_params, normalizer_params=running_statistics.init_state( specs.Array((env.observation_size,), jnp.float32)), env_steps=0) @@ -329,7 +337,7 @@ def training_epoch_with_timing( if process_id == 0 and num_evals > 1: metrics = evaluator.run_evaluation( _unpmap( - (training_state.normalizer_params, training_state.params.policy)), + (training_state.normalizer_params, training_state.policy_params)), training_metrics={}) logging.info(metrics) progress_fn(0, metrics) From 8ff700541b048e7de98fdac3eb99b374a0ca759e Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sun, 20 Nov 2022 15:20:44 +0000 Subject: [PATCH 09/10] SHAC: layer norm and gradient clipping Starting to see progress training the ant environment. 
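For reference, a hedged sketch (not part of the patch) of how the gradient-clipped optimizer introduced below composes in optax: optax.clip(1.0) caps each gradient element at +/-1 before Adam sees it, which tempers the gradient spikes that backprop through the simulator tends to produce. The learning rate here is a placeholder; the patch itself wires in actor_learning_rate and critic_learning_rate.

import jax.numpy as jnp
import optax

optimizer = optax.chain(
    optax.clip(1.0),                                   # elementwise clip of the raw gradients
    optax.adam(learning_rate=2e-3, b1=0.7, b2=0.95))   # betas as used in this patch

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)

grads = {'w': jnp.array([0.5, 10.0, -200.0])}          # a spiky gradient
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)          # large entries were clipped to +/-1 before Adam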
--- brax/training/agents/shac/losses.py | 8 +++----- brax/training/agents/shac/networks.py | 6 ++++-- brax/training/agents/shac/train.py | 13 +++++++++---- brax/training/networks.py | 11 +++++++++-- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index 68ad34a3..c2290cc3 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Proximal policy optimization training. +"""Short-Horizon Actor Critic. -See: https://arxiv.org/pdf/1707.06347.pdf +See: https://arxiv.org/pdf/2204.07137.pdf """ from typing import Any, Tuple @@ -46,8 +46,7 @@ def compute_shac_policy_loss( reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]: """Computes SHAC critic loss. - This implements Eq. 5 of 2204.07137. It needs to account for any episodes where - the episode terminates and include the terminal values appopriately. + This implements Eq. 5 of 2204.07137. Args: policy_params: Policy network parameters @@ -129,7 +128,6 @@ def compute_shac_critic_loss( params: Params, normalizer_params: Any, data: types.Transition, - rng: jnp.ndarray, shac_network: shac_networks.SHACNetworks, discounting: float = 0.9, reward_scaling: float = 1.0, diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py index c4e325af..47a4a0b1 100644 --- a/brax/training/agents/shac/networks.py +++ b/brax/training/agents/shac/networks.py @@ -76,12 +76,14 @@ def make_shac_networks( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=policy_hidden_layer_sizes, - activation=activation) + activation=activation, + layer_norm=True) value_network = networks.make_value_network( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=value_hidden_layer_sizes, - activation=activation) + activation=activation, + layer_norm=True) return SHACNetworks( policy_network=policy_network, diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index 9510afce..e4f621fe 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -130,8 +130,14 @@ def train(environment: envs.Env, preprocess_observations_fn=normalize) make_policy = shac_networks.make_inference_fn(shac_network) - policy_optimizer = optax.adam(learning_rate=actor_learning_rate) - value_optimizer = optax.adam(learning_rate=critic_learning_rate) + policy_optimizer = optax.chain( + optax.clip(1.0), + optax.adam(learning_rate=actor_learning_rate, b1=0.7, b2=0.95) + ) + value_optimizer = optax.chain( + optax.clip(1.0), + optax.adam(learning_rate=critic_learning_rate, b1=0.7, b2=0.95) + ) value_loss_fn = functools.partial( shac_losses.compute_shac_critic_loss, @@ -184,6 +190,7 @@ def f(carry, unused_t): policy_gradient_update_fn = gradients.gradient_update_fn( rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) + policy_gradient_update_fn = jax.jit(policy_gradient_update_fn) def minibatch_step( carry, data: types.Transition, @@ -194,7 +201,6 @@ def minibatch_step( params, normalizer_params, data, - key_loss, optimizer_state=optimizer_state) return (optimizer_state, params, key), metrics @@ -317,7 +323,6 @@ def training_epoch_with_timing( key_envs = jnp.reshape(key_envs, (local_devices_to_use, -1) + key_envs.shape[1:]) env_state = reset_fn(key_envs) - 
print(f'env_state: {env_state.qp.pos.shape}') if not eval_env: eval_env = env diff --git a/brax/training/networks.py b/brax/training/networks.py index 5856360a..903d1008 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -41,6 +41,7 @@ class MLP(linen.Module): kernel_init: Initializer = jax.nn.initializers.lecun_uniform() activate_final: bool = False bias: bool = True + layer_norm: bool = True @linen.compact def __call__(self, data: jnp.ndarray): @@ -52,6 +53,8 @@ def __call__(self, data: jnp.ndarray): kernel_init=self.kernel_init, use_bias=self.bias)( hidden) + if self.layer_norm: + hidden = linen.LayerNorm()(hidden) if i != len(self.layer_sizes) - 1 or self.activate_final: hidden = self.activation(hidden) return hidden @@ -86,11 +89,13 @@ def make_policy_network( preprocess_observations_fn: types.PreprocessObservationFn = types .identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), - activation: ActivationFn = linen.relu) -> FeedForwardNetwork: + activation: ActivationFn = linen.relu, + layer_norm: bool = False) -> FeedForwardNetwork: """Creates a policy network.""" policy_module = MLP( layer_sizes=list(hidden_layer_sizes) + [param_size], activation=activation, + layer_norm=layer_norm, kernel_init=jax.nn.initializers.lecun_uniform()) def apply(processor_params, policy_params, obs): @@ -107,11 +112,13 @@ def make_value_network( preprocess_observations_fn: types.PreprocessObservationFn = types .identity_observation_preprocessor, hidden_layer_sizes: Sequence[int] = (256, 256), - activation: ActivationFn = linen.relu) -> FeedForwardNetwork: + activation: ActivationFn = linen.relu, + layer_norm: bool = False) -> FeedForwardNetwork: """Creates a policy network.""" value_module = MLP( layer_sizes=list(hidden_layer_sizes) + [1], activation=activation, + layer_norm=layer_norm, kernel_init=jax.nn.initializers.lecun_uniform()) def apply(processor_params, policy_params, obs): From afec266ccb78761761bf2773db6b67b07129c3a7 Mon Sep 17 00:00:00 2001 From: James Cotton Date: Sun, 20 Nov 2022 20:01:26 +0000 Subject: [PATCH 10/10] SHAC: tweak layer norm --- brax/training/agents/ppo/networks.py | 9 ++++++--- brax/training/agents/shac/losses.py | 5 ++--- brax/training/agents/shac/networks.py | 7 ++++--- brax/training/agents/shac/train.py | 4 +++- brax/training/networks.py | 4 ++-- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py index 4631cb4a..084dee30 100644 --- a/brax/training/agents/ppo/networks.py +++ b/brax/training/agents/ppo/networks.py @@ -66,7 +66,8 @@ def make_ppo_networks( .identity_observation_preprocessor, policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, value_hidden_layer_sizes: Sequence[int] = (256,) * 5, - activation: networks.ActivationFn = linen.swish) -> PPONetworks: + activation: networks.ActivationFn = linen.swish, + layer_norm: bool = False) -> PPONetworks: """Make PPO networks with preprocessor.""" parametric_action_distribution = distribution.NormalTanhDistribution( event_size=action_size) @@ -75,12 +76,14 @@ def make_ppo_networks( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=policy_hidden_layer_sizes, - activation=activation) + activation=activation, + layer_norm=layer_norm) value_network = networks.make_value_network( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=value_hidden_layer_sizes, - activation=activation) + activation=activation, + 
layer_norm=layer_norm) return PPONetworks( policy_network=policy_network, diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py index c2290cc3..692536c3 100644 --- a/brax/training/agents/shac/losses.py +++ b/brax/training/agents/shac/losses.py @@ -89,7 +89,7 @@ def compute_shac_policy_loss( # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227 def sum_step(carry, target_t): gam, rew_acc = carry - reward, v, termination = target_t + reward, termination = target_t # clean up gamma and rew_acc for done envs, otherwise update rew_acc = jnp.where(termination, 0, rew_acc + gam * reward) @@ -100,7 +100,7 @@ def sum_step(carry, target_t): rew_acc = jnp.zeros_like(terminal_values) gam = jnp.ones_like(terminal_values) (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc), - (rewards, values, termination)) + (rewards, termination)) policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values) # for trials that are truncated (i.e. hit the episode length) include reward for @@ -118,7 +118,6 @@ def sum_step(carry, target_t): total_loss = policy_loss + entropy_loss return total_loss, { - 'total_loss': total_loss, 'policy_loss': policy_loss, 'entropy_loss': entropy_loss } diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py index 47a4a0b1..bd30d702 100644 --- a/brax/training/agents/shac/networks.py +++ b/brax/training/agents/shac/networks.py @@ -67,7 +67,8 @@ def make_shac_networks( .identity_observation_preprocessor, policy_hidden_layer_sizes: Sequence[int] = (32,) * 4, value_hidden_layer_sizes: Sequence[int] = (256,) * 5, - activation: networks.ActivationFn = linen.swish) -> SHACNetworks: + activation: networks.ActivationFn = linen.elu, + layer_norm: bool = True) -> SHACNetworks: """Make SHAC networks with preprocessor.""" parametric_action_distribution = distribution.NormalTanhDistribution( event_size=action_size) @@ -77,13 +78,13 @@ def make_shac_networks( preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=policy_hidden_layer_sizes, activation=activation, - layer_norm=True) + layer_norm=layer_norm) value_network = networks.make_value_network( observation_size, preprocess_observations_fn=preprocess_observations_fn, hidden_layer_sizes=value_hidden_layer_sizes, activation=activation, - layer_norm=True) + layer_norm=layer_norm) return SHACNetworks( policy_network=policy_network, diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py index e4f621fe..a8bfeafd 100644 --- a/brax/training/agents/shac/train.py +++ b/brax/training/agents/shac/train.py @@ -83,6 +83,7 @@ def train(environment: envs.Env, reward_scaling: float = 1., tau: float = 0.005, # this is 1-alpha from the original paper lambda_: float = .95, + td_lambda: bool = True, deterministic_eval: bool = False, network_factory: types.NetworkFactory[ shac_networks.SHACNetworks] = shac_networks.make_shac_networks, @@ -144,7 +145,8 @@ def train(environment: envs.Env, shac_network=shac_network, discounting=discounting, reward_scaling=reward_scaling, - lambda_=lambda_) + lambda_=lambda_, + td_lambda=td_lambda) value_gradient_update_fn = gradients.gradient_update_fn( value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True) diff --git a/brax/training/networks.py b/brax/training/networks.py index 903d1008..404e73a9 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -53,10 +53,10 @@ def __call__(self, 
data: jnp.ndarray): kernel_init=self.kernel_init, use_bias=self.bias)( hidden) - if self.layer_norm: - hidden = linen.LayerNorm()(hidden) if i != len(self.layer_sizes) - 1 or self.activate_final: hidden = self.activation(hidden) + if self.layer_norm: + hidden = linen.LayerNorm()(hidden) return hidden
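For clarity, a small standalone sketch (assuming elu, the SHAC default activation in this series) of the per-layer ordering the final hunk produces: Dense, then activation, then LayerNorm, with the output layer left raw so value and policy heads stay unconstrained.

from typing import Sequence

import jax
import jax.numpy as jnp
from flax import linen

class NormedMLP(linen.Module):
  """Hidden layers use Dense -> activation -> LayerNorm; the head is raw."""
  layer_sizes: Sequence[int] = (256, 256, 1)

  @linen.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    for i, size in enumerate(self.layer_sizes):
      x = linen.Dense(size)(x)
      if i != len(self.layer_sizes) - 1:
        x = linen.elu(x)           # activate first...
        x = linen.LayerNorm()(x)   # ...then normalize the activated features
    return x

params = NormedMLP().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
out = NormedMLP().apply(params, jnp.ones((4, 8)))   # shape (4, 1)

One plausible reading of the tweak: normalizing after the nonlinearity (rather than before, as in the previous commit) means the statistics are computed on the features the next Dense layer actually consumes.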