From 334a618cf761e8b701c615232907ff13984dcce0 Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Thu, 26 Nov 2020 21:49:31 +0100 Subject: [PATCH 1/7] Extend environment API to support seeding --- src/garage/_environment.py | 19 +++++++++++++++++++ src/garage/envs/dm_control/dm_control_env.py | 7 +++++++ src/garage/envs/grid_world_env.py | 7 +++++++ src/garage/envs/gym_env.py | 7 +++++++ src/garage/envs/metaworld_set_task_env.py | 7 +++++++ src/garage/envs/point_env.py | 8 ++++++++ 6 files changed, 55 insertions(+) diff --git a/src/garage/_environment.py b/src/garage/_environment.py index 78bf7b0f61..49eaad1d87 100644 --- a/src/garage/_environment.py +++ b/src/garage/_environment.py @@ -159,6 +159,8 @@ class Environment(abc.ABC): +-----------------------+ | visualize() | +-----------------------+ + | seed() | + +-----------------------+ | close() | +-----------------------+ @@ -350,6 +352,16 @@ def _validate_render_mode(self, mode): 'got render mode {} instead.'.format( self.render_modes, mode)) + @abc.abstractmethod + def seed(self, seed): + """Sets environment seeds. + + This method should set all seeds specific to the environment library. + + Args: + seed (int): The seed value to set + """ + def __del__(self): """Environment destructor.""" self.close() @@ -452,6 +464,13 @@ def visualize(self): """Creates a visualization of the wrapped environment.""" self._env.visualize() + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the wrapped env.""" self._env.close() diff --git a/src/garage/envs/dm_control/dm_control_env.py b/src/garage/envs/dm_control/dm_control_env.py index 5c47b4634f..98d2dd6b70 100644 --- a/src/garage/envs/dm_control/dm_control_env.py +++ b/src/garage/envs/dm_control/dm_control_env.py @@ -184,6 +184,13 @@ def visualize(self): self._viewer = DmControlViewer(title=title) self._viewer.launch(self._env) + def seed(self, seed): + """Sets all environment seeds. 
+ + Args: + seed (int): The seed value to set + """ + def close(self): """Close the environment.""" if self._viewer: diff --git a/src/garage/envs/grid_world_env.py b/src/garage/envs/grid_world_env.py index af624e7b76..40d8f12559 100644 --- a/src/garage/envs/grid_world_env.py +++ b/src/garage/envs/grid_world_env.py @@ -184,6 +184,13 @@ def render(self, mode): def visualize(self): """Creates a visualization of the environment.""" + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the env.""" diff --git a/src/garage/envs/gym_env.py b/src/garage/envs/gym_env.py index 321fe0ecaa..1a79efdf72 100644 --- a/src/garage/envs/gym_env.py +++ b/src/garage/envs/gym_env.py @@ -288,6 +288,13 @@ def visualize(self): self._env.render(mode='human') self._visualize = True + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the wrapped env.""" self._close_viewer_window() diff --git a/src/garage/envs/metaworld_set_task_env.py b/src/garage/envs/metaworld_set_task_env.py index e5d5ad5fe1..5d9f6be1c8 100644 --- a/src/garage/envs/metaworld_set_task_env.py +++ b/src/garage/envs/metaworld_set_task_env.py @@ -251,6 +251,13 @@ def visualize(self): """Creates a visualization of the wrapped environment.""" self._current_env.visualize() + def seed(self, seed): + """Sets all environment seeds. + + Args: + seed (int): The seed value to set + """ + def close(self): """Close the wrapped env.""" for env in self._envs.values(): diff --git a/src/garage/envs/point_env.py b/src/garage/envs/point_env.py index 8d4f2fc0ad..47262db4fe 100644 --- a/src/garage/envs/point_env.py +++ b/src/garage/envs/point_env.py @@ -182,6 +182,14 @@ def visualize(self): def close(self): """Close the env.""" + def seed(self, seed): + """Sets all environment seeds. 
+ + Args: + seed (int): The seed value to set + + """ + # pylint: disable=no-self-use def sample_tasks(self, num_tasks): """Sample a list of `num_tasks` tasks. From 5417cd1dc22aef89fdab47d0ca2fc2869a9897a5 Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Thu, 26 Nov 2020 21:52:23 +0100 Subject: [PATCH 2/7] Set seeds for Gym envs --- src/garage/envs/gym_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/garage/envs/gym_env.py b/src/garage/envs/gym_env.py index 1a79efdf72..2940aee3e7 100644 --- a/src/garage/envs/gym_env.py +++ b/src/garage/envs/gym_env.py @@ -294,6 +294,8 @@ def seed(self, seed): Args: seed (int): The seed value to set """ + self._env.seed(seed) + self.action_space.seed(seed) def close(self): """Close the wrapped env.""" From 4ce934cb0897fe6c79c5328206e6bd8913f82fc8 Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Thu, 26 Nov 2020 22:03:11 +0100 Subject: [PATCH 3/7] Set env seeds whenever environment is changed --- src/garage/sampler/_functions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/garage/sampler/_functions.py b/src/garage/sampler/_functions.py index ac6545b063..d8140859f0 100644 --- a/src/garage/sampler/_functions.py +++ b/src/garage/sampler/_functions.py @@ -1,5 +1,6 @@ """Functions used by multiple Samplers or Workers.""" from garage import Environment +from garage.experiment import deterministic from garage.sampler.env_update import EnvUpdate @@ -33,6 +34,7 @@ def _apply_env_update(old_env, env_update): elif isinstance(env_update, Environment): if old_env is not None: old_env.close() + env_update.seed(deterministic.get_seed()) return env_update, True else: raise TypeError('Unknown environment update type.') From 70ec94e7e3937312aee22cda42c8e5f2a63be285 Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Thu, 26 Nov 2020 22:05:31 +0100 Subject: [PATCH 4/7] Add test for determinism of LocalSampler --- tests/garage/sampler/test_local_sampler.py | 35 +++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 
deletion(-) diff --git a/tests/garage/sampler/test_local_sampler.py b/tests/garage/sampler/test_local_sampler.py index 653fa46cfa..f3a5edb23f 100644 --- a/tests/garage/sampler/test_local_sampler.py +++ b/tests/garage/sampler/test_local_sampler.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from garage.envs import PointEnv +from garage.envs import GymEnv, PointEnv from garage.experiment.task_sampler import SetTaskSampler from garage.np.policies import FixedPolicy, ScriptedPolicy from garage.sampler import LocalSampler, WorkerFactory @@ -103,3 +103,36 @@ def test_no_seed(): sampler = LocalSampler.from_worker_factory(workers, policy, env) episodes = sampler.obtain_samples(0, 160, policy) assert sum(episodes.lengths) >= 160 + + +def test_deterministic_on_policy_sampling(): + max_episode_length = 1 + env1 = GymEnv('LunarLander-v2') + env2 = GymEnv('LunarLander-v2') + # Fix the action sequence + env1.action_space.seed(10) + env2.action_space.seed(10) + policy1 = FixedPolicy(env1.spec, + scripted_actions=[ + env1.action_space.sample() + for _ in range(max_episode_length) + ]) + policy2 = FixedPolicy(env2.spec, + scripted_actions=[ + env2.action_space.sample() + for _ in range(max_episode_length) + ]) + n_workers = 1 + worker1 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + worker2 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + sampler1 = LocalSampler.from_worker_factory(worker1, policy1, env1) + sampler2 = LocalSampler.from_worker_factory(worker2, policy2, env2) + episodes1 = sampler1.obtain_samples(0, 1, policy1) + episodes2 = sampler2.obtain_samples(0, 1, policy2) + assert np.array_equal(episodes1.observations, episodes2.observations) + assert np.array_equal(episodes1.next_observations, + episodes2.next_observations) From 0a0c019f7c318f5e30004b5f143c10dac837b31f Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Wed, 9 Dec 2020 13:01:55 +0100 Subject: [PATCH 5/7] Set seed in 
env wrapper --- src/garage/_environment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/garage/_environment.py b/src/garage/_environment.py index 49eaad1d87..0425610e64 100644 --- a/src/garage/_environment.py +++ b/src/garage/_environment.py @@ -470,6 +470,7 @@ def seed(self, seed): Args: seed (int): The seed value to set """ + self._env.seed(seed) def close(self): """Close the wrapped env.""" From 207bdd78ba5dfedc12f0bde6ba88bce1fbc7e31c Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Wed, 9 Dec 2020 15:39:22 +0100 Subject: [PATCH 6/7] Add seeding for dm control envs + test --- src/garage/envs/dm_control/dm_control_env.py | 3 ++ tests/garage/sampler/test_local_sampler.py | 48 ++++++++++++++++++-- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/garage/envs/dm_control/dm_control_env.py b/src/garage/envs/dm_control/dm_control_env.py index 98d2dd6b70..69764f5968 100644 --- a/src/garage/envs/dm_control/dm_control_env.py +++ b/src/garage/envs/dm_control/dm_control_env.py @@ -190,6 +190,9 @@ def seed(self, seed): Args: seed (int): The seed value to set """ + # pylint: disable=protected-access + self._env._task._random = np.random.RandomState(seed) + self.action_space.seed(seed) def close(self): """Close the environment.""" diff --git a/tests/garage/sampler/test_local_sampler.py b/tests/garage/sampler/test_local_sampler.py index f3a5edb23f..4494705b83 100644 --- a/tests/garage/sampler/test_local_sampler.py +++ b/tests/garage/sampler/test_local_sampler.py @@ -2,6 +2,7 @@ import pytest from garage.envs import GymEnv, PointEnv +from garage.envs.dm_control import DMControlEnv from garage.experiment.task_sampler import SetTaskSampler from garage.np.policies import FixedPolicy, ScriptedPolicy from garage.sampler import LocalSampler, WorkerFactory @@ -105,8 +106,8 @@ def test_no_seed(): assert sum(episodes.lengths) >= 160 -def test_deterministic_on_policy_sampling(): - max_episode_length = 1 +def test_deterministic_on_policy_sampling_gym_env(): + 
max_episode_length = 10 env1 = GymEnv('LunarLander-v2') env2 = GymEnv('LunarLander-v2') # Fix the action sequence @@ -122,7 +123,7 @@ def test_deterministic_on_policy_sampling(): env2.action_space.sample() for _ in range(max_episode_length) ]) - n_workers = 1 + n_workers = 2 worker1 = WorkerFactory(seed=10, max_episode_length=max_episode_length, n_workers=n_workers) @@ -131,8 +132,45 @@ def test_deterministic_on_policy_sampling(): n_workers=n_workers) sampler1 = LocalSampler.from_worker_factory(worker1, policy1, env1) sampler2 = LocalSampler.from_worker_factory(worker2, policy2, env2) - episodes1 = sampler1.obtain_samples(0, 1, policy1) - episodes2 = sampler2.obtain_samples(0, 1, policy2) + episodes1 = sampler1.obtain_samples(0, 10, policy1) + episodes2 = sampler2.obtain_samples(0, 10, policy2) assert np.array_equal(episodes1.observations, episodes2.observations) assert np.array_equal(episodes1.next_observations, episodes2.next_observations) + + +def test_deterministic_on_policy_sampling_dm_env(): + max_episode_length = 10 + env1 = DMControlEnv.from_suite('cartpole', 'balance') + env2 = DMControlEnv.from_suite('cartpole', 'balance') + # Fix the action sequence + env1.action_space.seed(10) + env2.action_space.seed(10) + policy1 = FixedPolicy(env1.spec, + scripted_actions=[ + env1.action_space.sample() + for _ in range(max_episode_length) + ]) + policy2 = FixedPolicy(env2.spec, + scripted_actions=[ + env2.action_space.sample() + for _ in range(max_episode_length) + ]) + n_workers = 2 + worker1 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + worker2 = WorkerFactory(seed=10, + max_episode_length=max_episode_length, + n_workers=n_workers) + sampler1 = LocalSampler.from_worker_factory(worker1, policy1, env1) + sampler2 = LocalSampler.from_worker_factory(worker2, policy2, env2) + episodes1 = sampler1.obtain_samples(0, 10, policy1) + episodes2 = sampler2.obtain_samples(0, 10, policy2) + assert 
np.array_equal(episodes1.observations, episodes2.observations) + assert np.array_equal(episodes1.next_observations, + episodes2.next_observations) + + +if __name__ == '__main__': + test_deterministic_on_policy_sampling_dm_env() From 1eec91280b430dc0056f8c6af10f26897a7ae5f7 Mon Sep 17 00:00:00 2001 From: Adrian Ziegler Date: Thu, 10 Dec 2020 18:19:58 +0100 Subject: [PATCH 7/7] Fix: remove script stub from test --- tests/garage/sampler/test_local_sampler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/garage/sampler/test_local_sampler.py b/tests/garage/sampler/test_local_sampler.py index 4494705b83..d5ed17d502 100644 --- a/tests/garage/sampler/test_local_sampler.py +++ b/tests/garage/sampler/test_local_sampler.py @@ -170,7 +170,3 @@ def test_deterministic_on_policy_sampling_dm_env(): assert np.array_equal(episodes1.observations, episodes2.observations) assert np.array_equal(episodes1.next_observations, episodes2.next_observations) - - -if __name__ == '__main__': - test_deterministic_on_policy_sampling_dm_env()