From afec266ccb78761761bf2773db6b67b07129c3a7 Mon Sep 17 00:00:00 2001
From: James Cotton
Date: Sun, 20 Nov 2022 20:01:26 +0000
Subject: [PATCH] SHAC: tweak layer norm

---
 brax/training/agents/ppo/networks.py  | 9 ++++++---
 brax/training/agents/shac/losses.py   | 5 ++---
 brax/training/agents/shac/networks.py | 7 ++++---
 brax/training/agents/shac/train.py    | 4 +++-
 brax/training/networks.py             | 4 ++--
 5 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py
index 4631cb4a..084dee30 100644
--- a/brax/training/agents/ppo/networks.py
+++ b/brax/training/agents/ppo/networks.py
@@ -66,7 +66,8 @@ def make_ppo_networks(
         .identity_observation_preprocessor,
     policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
     value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
-    activation: networks.ActivationFn = linen.swish) -> PPONetworks:
+    activation: networks.ActivationFn = linen.swish,
+    layer_norm: bool = False) -> PPONetworks:
   """Make PPO networks with preprocessor."""
   parametric_action_distribution = distribution.NormalTanhDistribution(
       event_size=action_size)
@@ -75,12 +76,14 @@ def make_ppo_networks(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=policy_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=layer_norm)
   value_network = networks.make_value_network(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=value_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=layer_norm)
 
   return PPONetworks(
       policy_network=policy_network,
diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py
index c2290cc3..692536c3 100644
--- a/brax/training/agents/shac/losses.py
+++ b/brax/training/agents/shac/losses.py
@@ -89,7 +89,7 @@ def compute_shac_policy_loss(
   # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
   def sum_step(carry, target_t):
     gam, rew_acc = carry
-    reward, v, termination = target_t
+    reward, termination = target_t
 
     # clean up gamma and rew_acc for done envs, otherwise update
     rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
@@ -100,7 +100,7 @@ def sum_step(carry, target_t):
   rew_acc = jnp.zeros_like(terminal_values)
   gam = jnp.ones_like(terminal_values)
   (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc),
-                                                         (rewards, values, termination))
+                                                         (rewards, termination))
 
   policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
   # for trials that are truncated (i.e. hit the episode length) include reward for
@@ -118,7 +118,6 @@ def sum_step(carry, target_t):
   total_loss = policy_loss + entropy_loss
 
   return total_loss, {
-      'total_loss': total_loss,
       'policy_loss': policy_loss,
       'entropy_loss': entropy_loss
   }
diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py
index 47a4a0b1..bd30d702 100644
--- a/brax/training/agents/shac/networks.py
+++ b/brax/training/agents/shac/networks.py
@@ -67,7 +67,8 @@ def make_shac_networks(
         .identity_observation_preprocessor,
     policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
     value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
-    activation: networks.ActivationFn = linen.swish) -> SHACNetworks:
+    activation: networks.ActivationFn = linen.elu,
+    layer_norm: bool = True) -> SHACNetworks:
   """Make SHAC networks with preprocessor."""
   parametric_action_distribution = distribution.NormalTanhDistribution(
       event_size=action_size)
@@ -77,13 +78,13 @@ def make_shac_networks(
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=policy_hidden_layer_sizes,
       activation=activation,
-      layer_norm=True)
+      layer_norm=layer_norm)
   value_network = networks.make_value_network(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=value_hidden_layer_sizes,
       activation=activation,
-      layer_norm=True)
+      layer_norm=layer_norm)
 
   return SHACNetworks(
       policy_network=policy_network,
diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py
index e4f621fe..a8bfeafd 100644
--- a/brax/training/agents/shac/train.py
+++ b/brax/training/agents/shac/train.py
@@ -83,6 +83,7 @@ def train(environment: envs.Env,
           reward_scaling: float = 1.,
           tau: float = 0.005,  # this is 1-alpha from the original paper
           lambda_: float = .95,
+          td_lambda: bool = True,
           deterministic_eval: bool = False,
           network_factory: types.NetworkFactory[
               shac_networks.SHACNetworks] = shac_networks.make_shac_networks,
@@ -144,7 +145,8 @@ def train(environment: envs.Env,
       shac_network=shac_network,
       discounting=discounting,
       reward_scaling=reward_scaling,
-      lambda_=lambda_)
+      lambda_=lambda_,
+      td_lambda=td_lambda)
   value_gradient_update_fn = gradients.gradient_update_fn(
       value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME,
       has_aux=True)
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 903d1008..404e73a9 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -53,10 +53,10 @@ def __call__(self, data: jnp.ndarray):
           kernel_init=self.kernel_init,
           use_bias=self.bias)(
               hidden)
-      if self.layer_norm:
-        hidden = linen.LayerNorm()(hidden)
       if i != len(self.layer_sizes) - 1 or self.activate_final:
         hidden = self.activation(hidden)
+        if self.layer_norm:
+          hidden = linen.LayerNorm()(hidden)
     return hidden
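
Note on the brax/training/networks.py change (not part of the patch): after this change, MLP applies LayerNorm to the post-activation features and skips it on the final, un-activated layer, instead of normalizing the raw output of every Dense layer. The sketch below is a standalone, simplified stand-in that shows only this placement; the field names mirror brax.training.networks.MLP, but the reduced constructor (no kernel_init/bias fields), the layer sizes, and the toy shapes are illustrative assumptions, not the exact brax implementation.

    # Sketch of the layer-norm placement this patch adopts: LayerNorm runs after
    # the activation, and only on layers that are activated.
    from typing import Callable, Sequence

    import jax
    import jax.numpy as jnp
    from flax import linen


    class MLP(linen.Module):
      layer_sizes: Sequence[int]
      activation: Callable[[jnp.ndarray], jnp.ndarray] = linen.elu
      layer_norm: bool = True
      activate_final: bool = False

      @linen.compact
      def __call__(self, data: jnp.ndarray):
        hidden = data
        for i, hidden_size in enumerate(self.layer_sizes):
          hidden = linen.Dense(hidden_size, name=f'hidden_{i}')(hidden)
          if i != len(self.layer_sizes) - 1 or self.activate_final:
            hidden = self.activation(hidden)
            if self.layer_norm:
              # Normalize post-activation features; previously LayerNorm ran
              # before the activation and on every layer, including the last.
              hidden = linen.LayerNorm()(hidden)
        return hidden


    if __name__ == '__main__':
      mlp = MLP(layer_sizes=(64, 64, 8))
      params = mlp.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))
      print(mlp.apply(params, jnp.ones((4, 16))).shape)  # (4, 8)

The rest of the patch threads configuration through: make_ppo_networks and make_shac_networks now accept layer_norm (default False for PPO, True for SHAC, whose default activation also changes from linen.swish to linen.elu), and train gains a td_lambda flag that is forwarded into the loss computation. A caller would typically toggle the network options via something like functools.partial(shac_networks.make_shac_networks, layer_norm=False) passed as network_factory.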