
Commit f43727e

Internal change
PiperOrigin-RevId: 696400506
Change-Id: Ia917fe590c842a999c377dfaf30b341d19367401
Brax Team authored and btaba committed Nov 14, 2024
1 parent eb66604 commit f43727e
Showing 4 changed files with 78 additions and 50 deletions.
8 changes: 4 additions & 4 deletions brax/training/agents/ppo/networks.py
@@ -34,15 +34,15 @@ class PPONetworks:
def make_inference_fn(ppo_networks: PPONetworks):
"""Creates params and inference function for the PPO agent."""

def make_policy(params: types.PolicyParams,
deterministic: bool = False) -> types.Policy:
def make_policy(
params: types.Params, deterministic: bool = False
) -> types.Policy:
policy_network = ppo_networks.policy_network
parametric_action_distribution = ppo_networks.parametric_action_distribution

def policy(observations: types.Observation,
key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]:
# Discard the value function.
param_subset = (params[0], params[1].policy)
param_subset = (params[0], params[1]) # normalizer and policy params
logits = policy_network.apply(*param_subset, observations)
if deterministic:
return ppo_networks.parametric_action_distribution.mode(logits), {}
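A minimal sketch of the new calling convention (not part of the commit; sizes, keys, and observations below are placeholders): make_policy now receives the full (normalizer, policy, value) params tuple and uses only the first two entries when building the inference function.

# Sketch only: builds throwaway params to show the expected tuple layout.
import jax
import jax.numpy as jnp
from brax.training.acme import running_statistics, specs
from brax.training.agents.ppo import networks as ppo_networks

obs_size, act_size = 8, 2  # hypothetical sizes
nets = ppo_networks.make_ppo_networks(obs_size, act_size)
make_policy = ppo_networks.make_inference_fn(nets)

key_policy, key_value, key_sample = jax.random.split(jax.random.PRNGKey(0), 3)
normalizer_params = running_statistics.init_state(
    specs.Array((obs_size,), jnp.dtype('float32')))
policy_params = nets.policy_network.init(key_policy)
value_params = nets.value_network.init(key_value)

# The tuple now carries value params as well; the policy ignores them.
policy = make_policy(
    (normalizer_params, policy_params, value_params), deterministic=True)
action, _ = policy(jnp.zeros((obs_size,)), key_sample)
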
40 changes: 29 additions & 11 deletions brax/training/agents/ppo/train.py
@@ -294,8 +294,11 @@ def training_step(
training_state, state, key = carry
key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)

policy = make_policy(
(training_state.normalizer_params, training_state.params))
policy = make_policy((
training_state.normalizer_params,
training_state.params.policy,
training_state.params.value,
))

def f(carry, unused_t):
current_state, current_key = carry
@@ -402,7 +405,11 @@ def training_epoch_with_timing(
if num_timesteps == 0:
return (
make_policy,
(training_state.normalizer_params, training_state.params),
(
training_state.normalizer_params,
training_state.params.policy,
training_state.params.value,
),
{},
)

@@ -436,9 +443,13 @@ def training_epoch_with_timing(
metrics = {}
if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.params)),
training_metrics={})
_unpmap((
training_state.normalizer_params,
training_state.params.policy,
training_state.params.value,
)),
training_metrics={},
)
logging.info(metrics)
progress_fn(0, metrics)

@@ -466,9 +477,13 @@ def training_epoch_with_timing(
if process_id == 0:
# Run evals.
metrics = evaluator.run_evaluation(
_unpmap(
(training_state.normalizer_params, training_state.params)),
training_metrics)
_unpmap((
training_state.normalizer_params,
training_state.params.policy,
training_state.params.value,
)),
training_metrics,
)
logging.info(metrics)
progress_fn(current_step, metrics)
params = _unpmap(
@@ -482,8 +497,11 @@ def training_epoch_with_timing(
# If there were no mistakes, the training_state should still be identical on all
# devices.
pmap.assert_is_replicated(training_state)
params = _unpmap(
(training_state.normalizer_params, training_state.params))
params = _unpmap((
training_state.normalizer_params,
training_state.params.policy,
training_state.params.value,
))
logging.info('total steps: %s', total_steps)
pmap.synchronize_hosts()
return (make_policy, params, metrics)
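For downstream callers, a sketch of consuming the new return value (hypothetical env name and deliberately tiny hyperparameters, not from the diff): the params element of the returned triple is now (normalizer, policy, value) rather than (normalizer, policy).

# Sketch: hyperparameters are illustrative and chosen only to satisfy the
# batch/env divisibility checks; real runs use far larger values.
from brax import envs
from brax.io import model
from brax.training.agents.ppo import train as ppo

make_policy, params, metrics = ppo.train(
    environment=envs.get_environment('ant', backend='mjx'),
    num_timesteps=10_000,
    episode_length=100,
    num_envs=16,
    batch_size=16,
    num_minibatches=4,
    unroll_length=5,
)

normalizer_params, policy_params, value_params = params     # now a 3-tuple
model.save_params('/tmp/ppo_params', params)                # value params included
policy = make_policy(model.load_params('/tmp/ppo_params'))  # inference unchanged
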
79 changes: 44 additions & 35 deletions brax/training/learner.py
@@ -28,12 +28,11 @@
from brax.training.agents.ars import train as ars
from brax.training.agents.es import train as es
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac
from brax.v1 import envs as envs_v1
from brax.v1.io import html as html_v1
from brax.v1.io import npy_file
import jax

import mediapy as media
from etils import epath

FLAGS = flags.FLAGS

@@ -43,14 +42,12 @@

# TODO move npy_file to v2.

flags.DEFINE_bool('use_v2', True, 'Use Brax v2.')
flags.DEFINE_enum(
'backend',
'mjx',
['mjx', 'spring', 'generalized', 'positional'],
'The physics backend to use.',
)
flags.DEFINE_bool('legacy_spring', False, 'Brax v1 backend.')
flags.DEFINE_integer('total_env_steps', 50000000,
'Number of env steps to run training for.')
flags.DEFINE_integer('num_evals', 10, 'How many times to run an eval.')
@@ -104,6 +101,7 @@
'grad_updates_per_step', 1,
'How many SAC gradient updates to run per one step in the '
'environment.')
flags.DEFINE_bool('q_network_layer_norm', False, 'Critic network layer norm.')
# PPO hps.
flags.DEFINE_float('gae_lambda', .95, 'General advantage estimation lambda.')
flags.DEFINE_float('clipping_epsilon', .3, 'Policy loss clipping epsilon.')
@@ -119,30 +117,41 @@
'Std of a random noise added by ARS.')
flags.DEFINE_float('reward_shift', 0.,
'A reward shift to get rid of "stay alive" bonus.')

# ARS hps.
flags.DEFINE_integer('policy_updates', None,
'Number of policy updates in APG.')
# Wrap.
flags.DEFINE_bool('playground_dm_control_suite', False,
'Wrap the environment for MuJoco Playground.')


def main(unused_argv):

if FLAGS.use_v2:
get_environment = functools.partial(
envs.get_environment, backend=FLAGS.backend
)
def get_env_factory():
"""Returns a function that creates an environment."""
if FLAGS.playground_dm_control_suite:
get_environment = lambda *args, **kwargs: mjp.wrapper.BraxEnvWrapper(
mjp.dm_control_suite.load(*args, **kwargs))
else:
get_environment = functools.partial(
envs_v1.get_environment, legacy_spring=FLAGS.legacy_spring
envs.get_environment, backend=FLAGS.backend
)
return get_environment


def main(unused_argv):

get_environment = get_env_factory()
with metrics.Writer(FLAGS.logdir) as writer:
writer.write_hparams({
'num_evals': FLAGS.num_evals,
'num_envs': FLAGS.num_envs,
'total_env_steps': FLAGS.total_env_steps
})
if FLAGS.learner == 'sac':
network_factory = sac_networks.make_sac_networks
if FLAGS.q_network_layer_norm:
network_factory = functools.partial(
sac_networks.make_sac_networks, q_network_layer_norm=True
)
make_policy, params, _ = sac.train(
environment=get_environment(FLAGS.env),
num_envs=FLAGS.num_envs,
@@ -153,6 +162,7 @@ def main(unused_argv):
batch_size=FLAGS.batch_size,
min_replay_size=FLAGS.min_replay_size,
max_replay_size=FLAGS.max_replay_size,
network_factory=network_factory,
learning_rate=FLAGS.learning_rate,
discounting=FLAGS.discounting,
max_devices_per_host=FLAGS.max_devices_per_host,
@@ -239,11 +249,7 @@ def main(unused_argv):
model.save_params(path, params)

# Output an episode trajectory.
if FLAGS.use_v2:
env = envs.create(FLAGS.env, backend=FLAGS.backend)
else:
env = envs_v1.create(FLAGS.env, legacy_spring=FLAGS.legacy_spring)

env = get_environment(FLAGS.env)
@jax.jit
def jit_next_state(state, key):
new_key, tmp_key = jax.random.split(key)
@@ -253,13 +259,11 @@ def jit_next_state(state, key):
def do_rollout(rng):
rng, env_key = jax.random.split(rng)
state = env.reset(env_key)
states = []
while not state.done:
if isinstance(env, envs.Env):
states.append(state.pipeline_state)
else:
states.append(state.qp)
t, states = 0, []
while not state.done and t < FLAGS.episode_length:
states.append(state)
state, _, rng = jit_next_state(state, rng)
t += 1
return states, rng

trajectories = []
@@ -268,20 +272,25 @@ def do_rollout(rng):
qps, rng = do_rollout(rng)
trajectories.append(qps)

if hasattr(env, 'sys'):
video_path = ''
if hasattr(env, 'sys') and hasattr(env.sys, 'link_names'):
for i in range(FLAGS.num_videos):
video_path = f'{FLAGS.logdir}/saved_videos/trajectory_{i:04d}.html'
html.save(
video_path,
env.sys.tree_replace({'opt.timestep': env.dt}),
[t.pipeline_state for t in trajectories[i]],
) # pytype: disable=wrong-arg-types
elif hasattr(env, 'render'):
for i in range(FLAGS.num_videos):
html_path = f'{FLAGS.logdir}/saved_videos/trajectory_{i:04d}.html'
if isinstance(env, envs.Env):
html.save(html_path, env.sys.tree_replace({'opt.timestep': env.dt}), trajectories[i])
else:
html_v1.save_html(html_path, env.sys, trajectories[i], make_dir=True)
path_ = epath.Path(f'{FLAGS.logdir}/saved_videos/trajectory_{i:04d}.mp4')
path_.parent.mkdir(parents=True)
frames = env.render(trajectories[i])
media.write_video(path_, frames, fps=1.0 / env.dt)
video_path = path_.as_posix()
elif FLAGS.num_videos > 0:
logging.warn('Cannot save videos for non-physics environments.')

for i in range(FLAGS.num_trajectories_npy):
qp_path = f'{FLAGS.logdir}/saved_qps/trajectory_{i:04d}.npy'
npy_file.save(qp_path, trajectories[i], make_dir=True)



if __name__ == '__main__':
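Stripped of the absl flag plumbing above, the new --q_network_layer_norm path reduces to a network_factory partial passed to sac.train; a standalone sketch (env name and step counts are placeholders):

# Sketch: mirrors the learner.py wiring for --q_network_layer_norm without flags.
import functools
from brax import envs
from brax.training.agents.sac import networks as sac_networks
from brax.training.agents.sac import train as sac

network_factory = functools.partial(
    sac_networks.make_sac_networks, q_network_layer_norm=True)

make_policy, params, _ = sac.train(
    environment=envs.get_environment('halfcheetah', backend='mjx'),
    num_timesteps=10_000,   # placeholder; real runs use many more steps
    episode_length=100,
    num_envs=4,
    batch_size=64,
    min_replay_size=512,
    network_factory=network_factory,
)
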
1 change: 1 addition & 0 deletions docs/release-notes/next-release.md
@@ -3,3 +3,4 @@
* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
* Add `layer_norm` to `make_q_network` and set `layer_norm` to `True` in `make_sac_networks` Q Network.
* Change PPO train function to return both value and policy network params, rather than just policy params.
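A sketch of what the new release-note entry enables (dummy sizes and freshly initialized params, purely for illustration): with value params now returned, the trained critic can be queried outside the train loop.

# Sketch: in real use, (normalizer_params, _, value_params) comes from ppo.train(...).
import jax
import jax.numpy as jnp
from brax.training.acme import running_statistics, specs
from brax.training.agents.ppo import networks as ppo_networks

obs_size = 8  # hypothetical observation size
nets = ppo_networks.make_ppo_networks(obs_size, action_size=2)
normalizer_params = running_statistics.init_state(
    specs.Array((obs_size,), jnp.dtype('float32')))
value_params = nets.value_network.init(jax.random.PRNGKey(0))

values = nets.value_network.apply(
    normalizer_params, value_params, jnp.zeros((16, obs_size)))
print(values.shape)  # (16,)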
