From 3162b5938c437d8a558d084091c00232fc248b89 Mon Sep 17 00:00:00 2001
From: Kate Baumli
Date: Thu, 23 Jul 2020 04:18:39 -0700
Subject: [PATCH] Make AtariTorso and RecurrentActor work without batch dimensions.

PiperOrigin-RevId: 322760491
Change-Id: I13ef2a252cc323a7da898954f774b549f4dbe1fd
---
 acme/agents/jax/actors.py      |  7 ++++---
 acme/agents/jax/actors_test.py | 12 +++++++-----
 acme/jax/networks/atari.py     | 14 +++++++++++---
 acme/jax/utils.py              |  4 ++++
 4 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/acme/agents/jax/actors.py b/acme/agents/jax/actors.py
index b8f7dc2898..471a44ce3d 100644
--- a/acme/agents/jax/actors.py
+++ b/acme/agents/jax/actors.py
@@ -58,6 +58,7 @@ def __init__(
 
   def select_action(self, observation: types.NestedArray) -> types.NestedArray:
     key = next(self._rng)
+    # TODO(b/161332815): Make JAX Actor work with batched or unbatched inputs.
     observation = utils.add_batch_dim(observation)
     action = self._policy(self._client.params, key, observation)
     return utils.to_numpy_squeeze(action)
@@ -101,11 +102,11 @@ def select_action(self, observation: types.NestedArray) -> types.NestedArray:
     action, new_state = self._recurrent_policy(
         self._client.params,
         key=next(self._rng),
-        observation=utils.add_batch_dim(observation),
+        observation=observation,
         core_state=self._state)
     self._prev_state = self._state  # Keep previous state to save in replay.
     self._state = new_state  # Keep new state for next policy call.
-    return utils.to_numpy_squeeze(action)
+    return utils.to_numpy(action)
 
   def observe_first(self, timestep: dm_env.TimeStep):
     if self._adder:
@@ -115,7 +116,7 @@ def observe_first(self, timestep: dm_env.TimeStep):
 
   def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
     if self._adder:
-      numpy_state = utils.to_numpy_squeeze((self._prev_state))
+      numpy_state = utils.to_numpy(self._prev_state)
       self._adder.add(action, next_timestep, extras=(numpy_state,))
 
   def update(self):
diff --git a/acme/agents/jax/actors_test.py b/acme/agents/jax/actors_test.py
index 6c5683174e..a2acf3703f 100644
--- a/acme/agents/jax/actors_test.py
+++ b/acme/agents/jax/actors_test.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 """Tests for actors."""
-from typing import Tuple
+from typing import Optional, Tuple
 
 from absl.testing import absltest
 from acme import environment_loop
@@ -85,14 +85,16 @@ def test_recurrent(self):
 
     @_transform_without_rng
     def network(inputs: jnp.ndarray, state: hk.LSTMState):
-      return hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])(inputs, state)
+      return hk.DeepRNN([lambda x: jnp.reshape(x, [-1]),
+                         hk.LSTM(output_size)])(inputs, state)
 
     @_transform_without_rng
-    def initial_state(batch_size: int):
-      network = hk.DeepRNN([hk.Flatten(), hk.LSTM(output_size)])
+    def initial_state(batch_size: Optional[int] = None):
+      network = hk.DeepRNN([lambda x: jnp.reshape(x, [-1]),
+                            hk.LSTM(output_size)])
       return network.initial_state(batch_size)
 
-    initial_state = initial_state.apply(initial_state.init(next(rng), 1), 1)
+    initial_state = initial_state.apply(initial_state.init(next(rng)))
     params = network.init(next(rng), obs, initial_state)
 
     def policy(
diff --git a/acme/jax/networks/atari.py b/acme/jax/networks/atari.py
index f1705dc5da..28a3a91bb7 100644
--- a/acme/jax/networks/atari.py
+++ b/acme/jax/networks/atari.py
@@ -52,12 +52,20 @@ def __init__(self):
         hk.Conv2D(64, [4, 4], 2),
         jax.nn.relu,
         hk.Conv2D(64, [3, 3], 1),
-        jax.nn.relu,
-        hk.Flatten(),
+        jax.nn.relu
     ])
 
   def __call__(self, inputs: Images) -> jnp.ndarray:
-    return self._network(inputs)
+    inputs_rank = jnp.ndim(inputs)
+    batched_inputs = inputs_rank == 4
+    if inputs_rank < 3 or inputs_rank > 4:
+      raise ValueError('Expected input BHWC or HWC. Got rank %d' % inputs_rank)
+
+    outputs = self._network(inputs)
+
+    if batched_inputs:
+      return jnp.reshape(outputs, [outputs.shape[0], -1])  # [B, D]
+    return jnp.reshape(outputs, [-1])  # [D]
 
 
 def dqn_atari_network(num_actions: int) -> base.QNetwork:
diff --git a/acme/jax/utils.py b/acme/jax/utils.py
index 8a79587695..8237de2ab0 100644
--- a/acme/jax/utils.py
+++ b/acme/jax/utils.py
@@ -65,6 +65,10 @@ def to_numpy_squeeze(values: types.Nest) -> types.NestedArray:
   return tree_util.tree_map(lambda x: np.array(x).squeeze(axis=0), values)
 
 
+def to_numpy(values: types.Nest) -> types.NestedArray:
+  return tree_util.tree_map(np.array, values)
+
+
 def fetch_devicearray(values: types.Nest) -> types.Nest:
   """Fetches and converts any DeviceArrays in `values`."""
 