From dcd825c4340f8c31ae02ba072ad4398c96f486bb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 3 Jul 2024 11:05:23 +0100 Subject: [PATCH 01/24] init --- .github/scripts/m1_script.sh | 2 +- .github/workflows/wheels.yml | 4 ++-- setup.py | 2 +- torchrl/collectors/collectors.py | 12 ++++++------ torchrl/envs/common.py | 13 +------------ torchrl/envs/transforms/transforms.py | 9 +-------- torchrl/modules/distributions/continuous.py | 5 +++-- torchrl/modules/tensordict_module/actors.py | 15 ++++++--------- version.txt | 2 +- 9 files changed, 22 insertions(+), 42 deletions(-) diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh index 6552d8e4622..6da1cad5d79 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/m1_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export TORCHRL_BUILD_VERSION=0.4.0 +export TORCHRL_BUILD_VERSION=0.5.0 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 9b2e57db531..7f89ef08635 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TORCHRL_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/setup.py b/setup.py index 0196cb4a8f4..95dc0802a4f 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,7 @@ def _main(argv): if is_nightly: tensordict_dep = "tensordict-nightly" else: - tensordict_dep = "tensordict>=0.4.0" + tensordict_dep = "tensordict>=0.5.0" if is_nightly: version = get_nightly_version() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 50e3dd5cc49..32294a25edd 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -2065,18 +2065,18 @@ def _queue_len(self) -> int: def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: - cat_results = 0 + cat_results = "stack" warnings.warn( f"`cat_results` was not specified in the constructor of {type(self).__name__}. " f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option is `cat_results='stack'` which provides " - f"the best interoperability across torchrl components. " + f"be packed: the preferred option and current default is `cat_results='stack'` " + f"which provides the best interoperability across torchrl components. " f"Other accepted values are `cat_results=0` (previous behaviour) and " f"`cat_results=-1` (cat along time dimension). Among these two, the latter " f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `0` (using torch.cat along first dimension)." - f"From v0.5 onward, this will default to `'stack'`. " - f"To suppress this warning, set stack_results to the desired value.", + f"Currently, the default value is `'stack'`." + f"From v0.6 onward, this warning will be removed. " + f"To suppress this warning, set `cat_results` to the desired value.", category=DeprecationWarning, ) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c965e7dedf3..e30de3534d9 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -15,7 +15,6 @@ import torch import torch.nn as nn from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.base import NO_DEFAULT from tensordict.utils import NestedKey from torchrl._utils import ( _ends_with, @@ -3020,21 +3019,11 @@ class _EnvWrapper(EnvBase): def __init__( self, *args, - device: DEVICE_TYPING = NO_DEFAULT, + device: DEVICE_TYPING = None, batch_size: Optional[torch.Size] = None, allow_done_after_reset: bool = False, **kwargs, ): - if device is NO_DEFAULT: - warnings.warn( - "Your wrapper was not given a device. Currently, this " - "value will default to 'cpu'. From v0.5 it will " - "default to `None`. With a device of None, no device casting " - "is performed and the resulting tensordicts are deviceless. " - "Please set your device accordingly.", - category=DeprecationWarning, - ) - device = torch.device("cpu") super().__init__( device=device, batch_size=batch_size, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index bec76c603e6..5900326c3ca 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -3411,14 +3411,7 @@ def __init__( out_keys_inv: Sequence[NestedKey] | None = None, ): if in_keys is not None and in_keys_inv is None: - warnings.warn( - "in_keys have been provided but not in_keys_inv. From v0.5, " - "this will result in in_keys_inv being an empty list whereas " - "now the input keys are retrieved automatically. " - "To silence this warning, pass the (possibly empty) " - "list of in_keys_inv.", - category=DeprecationWarning, - ) + in_keys_inv = [] self.dtype_in = dtype_in self.dtype_out = dtype_out diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 087cabe4186..38d8d1dfd02 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -481,9 +481,10 @@ def root_dist(self): @property def mode(self): warnings.warn( - "This computation of the mode is based on the first-order Taylor expansion " - "of the transform around the normal mean value, which can be inaccurate. " + "This computation of the mode is based on an inaccurate estimation of the mode " + "given the base_dist mode. " "To use a more stable implementation of the mode, use dist.get_mode() method instead. " + "To silence this warning, consider using the DETERMINISTIC exploration_type." "This implementation will be removed in v0.6.", category=DeprecationWarning, ) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 17b1ea77ee4..83b6a8d1fb3 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union import torch @@ -922,10 +921,9 @@ def __init__( out_keys: Optional[Sequence[NestedKey]] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated in v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated. " + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, _ = _process_action_space_spec(action_space, None) @@ -1136,10 +1134,9 @@ def __init__( action_mask_key: Optional[NestedKey] = None, ): if isinstance(action_space, TensorSpec): - warnings.warn( - "Using specs in action_space will be deprecated v0.4.0," - " please use the 'spec' argument if you want to provide an action spec", - category=DeprecationWarning, + raise RuntimeError( + "Using specs in action_space is deprecated." + "Please use the 'spec' argument if you want to provide an action spec" ) action_space, spec = _process_action_space_spec(action_space, spec) diff --git a/version.txt b/version.txt index 1d0ba9ea182..8f0916f768f 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.0 +0.5.0 From 14bcb80ead57dc47c7a2f4ecce79a1e05fd0b82c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 17:43:24 +0100 Subject: [PATCH 02/24] empty From 724c41ff77a54bcf71b3070d0d2f327b373d7f29 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 09:53:07 +0100 Subject: [PATCH 03/24] amend --- .github/scripts/{m1_script.sh => td_script.sh} | 0 .github/workflows/build-wheels-linux.yml | 1 + .github/workflows/build-wheels-m1.yml | 2 +- .github/workflows/build-wheels-windows.yml | 1 + .github/workflows/wheels.yml | 0 5 files changed, 3 insertions(+), 1 deletion(-) rename .github/scripts/{m1_script.sh => td_script.sh} (100%) delete mode 100644 .github/workflows/wheels.yml diff --git a/.github/scripts/m1_script.sh b/.github/scripts/td_script.sh similarity index 100% rename from .github/scripts/m1_script.sh rename to .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml index 5171a7c3e2a..f51c5ed79b6 100644 --- a/.github/workflows/build-wheels-linux.yml +++ b/.github/workflows/build-wheels-linux.yml @@ -45,3 +45,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-m1.yml b/.github/workflows/build-wheels-m1.yml index 84fe79d09d2..73a365a79f2 100644 --- a/.github/workflows/build-wheels-m1.yml +++ b/.github/workflows/build-wheels-m1.yml @@ -46,4 +46,4 @@ jobs: runner-type: macos-m1-stable smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} - env-var-script: .github/scripts/m1_script.sh + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/build-wheels-windows.yml b/.github/workflows/build-wheels-windows.yml index 683f2a93f69..1beef7318f4 100644 --- a/.github/workflows/build-wheels-windows.yml +++ b/.github/workflows/build-wheels-windows.yml @@ -46,3 +46,4 @@ jobs: package-name: ${{ matrix.package-name }} smoke-test-script: ${{ matrix.smoke-test-script }} trigger-event: ${{ github.event_name }} + env-var-script: .github/scripts/td_script.sh diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml deleted file mode 100644 index e69de29bb2d..00000000000 From 00bba31ddf0d79e2888d67e634c0883cc86b0f20 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 10:38:52 +0100 Subject: [PATCH 04/24] amend --- .../decision_transformer/dt.py | 4 +++- .../decision_transformer/utils.py | 22 +++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index 59dbcafd8c9..dcb074b77fe 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -56,7 +56,9 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create test environment - test_env = make_env(cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video) + test_env = make_env( + cfg.env, obs_loc, obs_std, from_pixels=cfg.logger.video, device=model_device + ) if cfg.logger.video: test_env = test_env.append_transform( VideoRecorder(logger, tag="rendered", in_keys=["pixels"]) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 7c9500aa4e7..47b782e09ce 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -134,7 +134,9 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False): return transformed_env -def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): +def make_parallel_env( + env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None +): if train: num_envs = env_cfg.num_train_envs else: @@ -142,10 +144,12 @@ def make_parallel_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False) def make_env(): with set_gym_backend(env_cfg.backend): - return make_base_env(env_cfg, from_pixels=from_pixels) + return make_base_env(env_cfg, from_pixels=from_pixels, device="cpu") env = make_transformed_env( - ParallelEnv(num_envs, EnvCreator(make_env), serial_for_single=True), + ParallelEnv( + num_envs, EnvCreator(make_env), serial_for_single=True, device=device + ), env_cfg, obs_loc, obs_std, @@ -154,11 +158,15 @@ def make_env(): return env -def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False): - env = make_parallel_env( - env_cfg, obs_loc, obs_std, train=train, from_pixels=from_pixels +def make_env(env_cfg, obs_loc, obs_std, train=False, from_pixels=False, device=None): + return make_parallel_env( + env_cfg, + obs_loc, + obs_std, + train=train, + from_pixels=from_pixels, + device=device, ) - return env # ==================================================================== From 4937046d9d3753ba2f7288530fffbc2764a2aad6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 10:41:18 +0100 Subject: [PATCH 05/24] amend --- sota-implementations/decision_transformer/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 47b782e09ce..409833c75fa 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -57,7 +57,7 @@ # ----------------- -def make_base_env(env_cfg, from_pixels=False): +def make_base_env(env_cfg, from_pixels=False, device=None): set_gym_backend(env_cfg.backend).set() env_library = LIBS[env_cfg.library] @@ -73,7 +73,7 @@ def make_base_env(env_cfg, from_pixels=False): if env_library is DMControlEnv: env_task = env_cfg.task env_kwargs.update({"task_name": env_task}) - env = env_library(**env_kwargs) + env = env_library(**env_kwargs, device=device) return env From 4cb6162c07ca9f0917fc66dd5aa165c5488d7f9f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 11:04:06 +0100 Subject: [PATCH 06/24] empty From eab40f97e3a7500395cf860bdf668b8b305a0a3a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 12:08:26 +0100 Subject: [PATCH 07/24] amend --- test/test_libs.py | 4 ++-- torchrl/data/tensor_specs.py | 2 ++ torchrl/envs/gym_like.py | 6 ++---- torchrl/envs/libs/gym.py | 5 ++--- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 42138b4ad9b..64f757659b1 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1101,7 +1101,7 @@ def test_gym_gymnasium_parallel(self, maybe_fork_ParallelEnv): def test_vecenvs_nan(self): # noqa: F811 # old versions of gym must return nan for next values when there is a done state torch.manual_seed(0) - env = GymEnv("CartPole-v0", num_envs=2) + env = GymEnv("CartPole-v0", num_envs=2, device="cpu") env.set_seed(0) rollout = env.rollout(200) assert torch.isfinite(rollout.get("observation")).all() @@ -1110,7 +1110,7 @@ def test_vecenvs_nan(self): # noqa: F811 del env # same with collector - env = GymEnv("CartPole-v0", num_envs=2) + env = GymEnv("CartPole-v0", num_envs=2) # , device="cpu") env.set_seed(0) c = SyncDataCollector( env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 04c24cb8d57..157c3e0ec9f 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -1143,6 +1143,7 @@ def __eq__(self, other): if not isinstance(other, LazyStackedTensorSpec): return False if self.device != other.device: + raise RuntimeError((self, other)) return False if len(self._specs) != len(other._specs): return False @@ -4778,6 +4779,7 @@ def _stack_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) + raise RuntimeError(list_of_spec) return LazyStackedTensorSpec(*list_of_spec, dim=dim) else: raise NotImplementedError diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 47f93f09779..c7935272c91 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -348,8 +348,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase: batch_size=tensordict.batch_size, ) if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) if self.info_dict_reader and (info_dict is not None): if not isinstance(info_dict, dict): @@ -393,8 +392,7 @@ def _reset( if key not in tensordict_out.keys(True, True): tensordict_out[key] = item.zero() if self.device is not None: - tensordict_out = tensordict_out.to(self.device, non_blocking=True) - self._sync_device() + tensordict_out = tensordict_out.to(self.device) return tensordict_out @abc.abstractmethod diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 07c48587c14..9195929e31d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -27,7 +27,6 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, - LazyStackedTensorSpec, MultiDiscreteTensorSpec, MultiOneHotDiscreteTensorSpec, OneHotDiscreteTensorSpec, @@ -246,8 +245,8 @@ def _gym_to_torchrl_spec_transform( ).expand(batch_size) gym_spaces = gym_backend("spaces") if isinstance(spec, gym_spaces.tuple.Tuple): - result = LazyStackedTensorSpec( - *[ + result = torch.stack( + [ _gym_to_torchrl_spec_transform( s, device=device, From 0334ca5f7b434a8ca0666116855bc06a6d41014d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 12:08:40 +0100 Subject: [PATCH 08/24] amend --- test/test_libs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 64f757659b1..42138b4ad9b 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1101,7 +1101,7 @@ def test_gym_gymnasium_parallel(self, maybe_fork_ParallelEnv): def test_vecenvs_nan(self): # noqa: F811 # old versions of gym must return nan for next values when there is a done state torch.manual_seed(0) - env = GymEnv("CartPole-v0", num_envs=2, device="cpu") + env = GymEnv("CartPole-v0", num_envs=2) env.set_seed(0) rollout = env.rollout(200) assert torch.isfinite(rollout.get("observation")).all() @@ -1110,7 +1110,7 @@ def test_vecenvs_nan(self): # noqa: F811 del env # same with collector - env = GymEnv("CartPole-v0", num_envs=2) # , device="cpu") + env = GymEnv("CartPole-v0", num_envs=2) env.set_seed(0) c = SyncDataCollector( env, RandomPolicy(env.action_spec), total_frames=2000, frames_per_batch=200 From b077da1c8a0c4a23bfae61ddbd020ccf6e409b93 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 13:18:34 +0100 Subject: [PATCH 09/24] amend --- torchrl/data/tensor_specs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 157c3e0ec9f..0006213cd27 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4779,7 +4779,6 @@ def _stack_specs(list_of_spec, dim, out=None): dim += len(shape) + 1 shape.insert(dim, len(list_of_spec)) return spec0.clone().unsqueeze(dim).expand(shape) - raise RuntimeError(list_of_spec) return LazyStackedTensorSpec(*list_of_spec, dim=dim) else: raise NotImplementedError From 312b82e891f5b74520dd32ffd3825280dae4ab8e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 13:59:46 +0100 Subject: [PATCH 10/24] amend --- sota-implementations/a2c/a2c_atari.py | 2 +- sota-implementations/a2c/a2c_mujoco.py | 2 +- sota-implementations/cql/cql_offline.py | 2 +- sota-implementations/cql/cql_online.py | 2 +- sota-implementations/cql/discrete_cql_online.py | 2 +- sota-implementations/ddpg/ddpg.py | 2 +- sota-implementations/decision_transformer/dt.py | 2 +- sota-implementations/decision_transformer/online_dt.py | 2 +- sota-implementations/discrete_sac/discrete_sac.py | 2 +- sota-implementations/dqn/dqn_atari.py | 2 +- sota-implementations/dqn/dqn_cartpole.py | 2 +- sota-implementations/dreamer/dreamer.py | 4 ++-- sota-implementations/dreamer/dreamer_utils.py | 2 +- sota-implementations/impala/impala_multi_node_ray.py | 2 +- sota-implementations/impala/impala_multi_node_submitit.py | 2 +- sota-implementations/impala/impala_single_node.py | 2 +- sota-implementations/iql/discrete_iql.py | 2 +- sota-implementations/iql/iql_offline.py | 2 +- sota-implementations/iql/iql_online.py | 2 +- sota-implementations/multiagent/iql.py | 2 +- sota-implementations/multiagent/maddpg_iddpg.py | 2 +- sota-implementations/multiagent/mappo_ippo.py | 2 +- sota-implementations/multiagent/qmix_vdn.py | 2 +- sota-implementations/multiagent/sac.py | 2 +- sota-implementations/ppo/ppo_atari.py | 2 +- sota-implementations/ppo/ppo_mujoco.py | 2 +- sota-implementations/sac/sac.py | 2 +- sota-implementations/td3/td3.py | 2 +- 28 files changed, 29 insertions(+), 29 deletions(-) diff --git a/sota-implementations/a2c/a2c_atari.py b/sota-implementations/a2c/a2c_atari.py index 775dcfe206d..f8c18147306 100644 --- a/sota-implementations/a2c/a2c_atari.py +++ b/sota-implementations/a2c/a2c_atari.py @@ -201,7 +201,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/a2c/a2c_mujoco.py b/sota-implementations/a2c/a2c_mujoco.py index 0276039058f..d115174eb9c 100644 --- a/sota-implementations/a2c/a2c_mujoco.py +++ b/sota-implementations/a2c/a2c_mujoco.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_in_batch) // cfg.logger.test_interval cur_test_frame = (i * frames_in_batch) // cfg.logger.test_interval final = collected_frames >= collector.total_frames diff --git a/sota-implementations/cql/cql_offline.py b/sota-implementations/cql/cql_offline.py index d8185c8091c..5ca70f83b53 100644 --- a/sota-implementations/cql/cql_offline.py +++ b/sota-implementations/cql/cql_offline.py @@ -150,7 +150,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/cql/cql_online.py b/sota-implementations/cql/cql_online.py index 5f8f81357c8..cf629ed0733 100644 --- a/sota-implementations/cql/cql_online.py +++ b/sota-implementations/cql/cql_online.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // evaluation_interval final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/cql/discrete_cql_online.py b/sota-implementations/cql/discrete_cql_online.py index 4b6f14cd058..d0d6693eb97 100644 --- a/sota-implementations/cql/discrete_cql_online.py +++ b/sota-implementations/cql/discrete_cql_online.py @@ -183,7 +183,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/ddpg/ddpg.py b/sota-implementations/ddpg/ddpg.py index eb0b88c26f7..a92ee6185c3 100644 --- a/sota-implementations/ddpg/ddpg.py +++ b/sota-implementations/ddpg/ddpg.py @@ -185,7 +185,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/decision_transformer/dt.py b/sota-implementations/decision_transformer/dt.py index dcb074b77fe..9cca9fd8af5 100644 --- a/sota-implementations/decision_transformer/dt.py +++ b/sota-implementations/decision_transformer/dt.py @@ -116,7 +116,7 @@ def main(cfg: "DictConfig"): # noqa: F821 to_log = {"train/loss": loss_vals["loss"]} # Evaluation - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): if i % pretrain_log_interval == 0: eval_td = test_env.rollout( max_steps=eval_steps, diff --git a/sota-implementations/decision_transformer/online_dt.py b/sota-implementations/decision_transformer/online_dt.py index 5cb297e5c0b..da2241ce9fa 100644 --- a/sota-implementations/decision_transformer/online_dt.py +++ b/sota-implementations/decision_transformer/online_dt.py @@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821 } # Evaluation - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): inference_policy.eval() if i % pretrain_log_interval == 0: eval_td = test_env.rollout( diff --git a/sota-implementations/discrete_sac/discrete_sac.py b/sota-implementations/discrete_sac/discrete_sac.py index 6e100f92dc3..386f743c7d3 100644 --- a/sota-implementations/discrete_sac/discrete_sac.py +++ b/sota-implementations/discrete_sac/discrete_sac.py @@ -204,7 +204,7 @@ def main(cfg: "DictConfig"): # noqa: F821 cur_test_frame = (i * frames_per_batch) // eval_iter final = current_frames >= collector.total_frames if (i >= 1 and (prev_test_frame < cur_test_frame)) or final: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/dqn/dqn_atari.py b/sota-implementations/dqn/dqn_atari.py index 90f93551d4d..906273ee2f5 100644 --- a/sota-implementations/dqn/dqn_atari.py +++ b/sota-implementations/dqn/dqn_atari.py @@ -199,7 +199,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dqn/dqn_cartpole.py b/sota-implementations/dqn/dqn_cartpole.py index ac3f17a9203..173f88f7028 100644 --- a/sota-implementations/dqn/dqn_cartpole.py +++ b/sota-implementations/dqn/dqn_cartpole.py @@ -180,7 +180,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get and log evaluation rewards and eval time - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): prev_test_frame = ((i - 1) * frames_per_batch) // test_interval cur_test_frame = (i * frames_per_batch) // test_interval final = current_frames >= collector.total_frames diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e7b346b2b22..af8d3334950 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -284,7 +284,7 @@ def compile_rssms(module): # Evaluation if (i % eval_iter) == 0: # Real env - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = test_env.rollout( eval_rollout_steps, policy, @@ -298,7 +298,7 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index ff14871b011..59a17ff8648 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -535,7 +535,7 @@ def _dreamer_make_actor_real( SafeProbabilisticModule( in_keys=["loc", "scale"], out_keys=[action_key], - default_interaction_type=InteractionType.MODE, + default_interaction_type=InteractionType.DETERMINISTIC, distribution_class=TanhNormal, distribution_kwargs={"tanh_loc": True}, spec=CompositeSpec( diff --git a/sota-implementations/impala/impala_multi_node_ray.py b/sota-implementations/impala/impala_multi_node_ray.py index 0482a595ffa..1998c044305 100644 --- a/sota-implementations/impala/impala_multi_node_ray.py +++ b/sota-implementations/impala/impala_multi_node_ray.py @@ -247,7 +247,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_multi_node_submitit.py b/sota-implementations/impala/impala_multi_node_submitit.py index ce96cf06ce8..fdee4256c42 100644 --- a/sota-implementations/impala/impala_multi_node_submitit.py +++ b/sota-implementations/impala/impala_multi_node_submitit.py @@ -239,7 +239,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/impala/impala_single_node.py b/sota-implementations/impala/impala_single_node.py index bb0f314197a..cf583909620 100644 --- a/sota-implementations/impala/impala_single_node.py +++ b/sota-implementations/impala/impala_single_node.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 33513dd3973..ae1894379fd 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -186,7 +186,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index d98724e1371..d1a16fd8192 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -130,7 +130,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # evaluation if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index b66c6f9dcf2..d50ff806294 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/multiagent/iql.py b/sota-implementations/multiagent/iql.py index 81551ebefb7..a4d2b88a9d0 100644 --- a/sota-implementations/multiagent/iql.py +++ b/sota-implementations/multiagent/iql.py @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/maddpg_iddpg.py b/sota-implementations/multiagent/maddpg_iddpg.py index 9d14ff04b04..bd44bb0a043 100644 --- a/sota-implementations/multiagent/maddpg_iddpg.py +++ b/sota-implementations/multiagent/maddpg_iddpg.py @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/mappo_ippo.py b/sota-implementations/multiagent/mappo_ippo.py index e752c4d73f2..fa006a7d4a2 100644 --- a/sota-implementations/multiagent/mappo_ippo.py +++ b/sota-implementations/multiagent/mappo_ippo.py @@ -236,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/qmix_vdn.py b/sota-implementations/multiagent/qmix_vdn.py index d294a9c783e..4e6a962c556 100644 --- a/sota-implementations/multiagent/qmix_vdn.py +++ b/sota-implementations/multiagent/qmix_vdn.py @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/multiagent/sac.py b/sota-implementations/multiagent/sac.py index 30b7e7e98bc..f7b2523010b 100644 --- a/sota-implementations/multiagent/sac.py +++ b/sota-implementations/multiagent/sac.py @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821 and cfg.logger.backend ): evaluation_start = time.time() - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): env_test.frames = [] rollouts = env_test.rollout( max_steps=cfg.env.max_steps, diff --git a/sota-implementations/ppo/ppo_atari.py b/sota-implementations/ppo/ppo_atari.py index 908cb7924a3..2b02254032a 100644 --- a/sota-implementations/ppo/ppo_atari.py +++ b/sota-implementations/ppo/ppo_atari.py @@ -217,7 +217,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch * frame_skip) // test_interval < ( i * frames_in_batch * frame_skip ) // test_interval: diff --git a/sota-implementations/ppo/ppo_mujoco.py b/sota-implementations/ppo/ppo_mujoco.py index e3e74971a49..219ae1b59b6 100644 --- a/sota-implementations/ppo/ppo_mujoco.py +++ b/sota-implementations/ppo/ppo_mujoco.py @@ -210,7 +210,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Get test rewards - with torch.no_grad(), set_exploration_type(ExplorationType.MODE): + with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC): if ((i - 1) * frames_in_batch) // cfg_logger_test_interval < ( i * frames_in_batch ) // cfg_logger_test_interval: diff --git a/sota-implementations/sac/sac.py b/sota-implementations/sac/sac.py index f7a399cda72..9904fe072ab 100644 --- a/sota-implementations/sac/sac.py +++ b/sota-implementations/sac/sac.py @@ -197,7 +197,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, diff --git a/sota-implementations/td3/td3.py b/sota-implementations/td3/td3.py index 97fd039c238..5fbc9b032d7 100644 --- a/sota-implementations/td3/td3.py +++ b/sota-implementations/td3/td3.py @@ -195,7 +195,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): eval_start = time.time() eval_rollout = eval_env.rollout( eval_rollout_steps, From 07b908264024f70f36512c860b0e99086510faa6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 14:54:10 +0100 Subject: [PATCH 11/24] amend --- test/test_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_env.py b/test/test_env.py index e6ca38b729c..f8f242f3955 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -2061,6 +2061,7 @@ def main_collector(j, q=None): total_frames=N * n_workers * 100, storing_device=device, device=device, + cat_results=-1, ) single_collectors = [ SyncDataCollector( From 3610510c39f2150198ef2dd2d80d37c7ba9760a1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 14:56:25 +0100 Subject: [PATCH 12/24] amend --- sota-implementations/dreamer/dreamer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index af8d3334950..e36d87deaa8 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -298,7 +298,9 @@ def compile_rssms(module): log_metrics(logger, eval_metrics, collected_frames) # Simulated env if model_based_env_eval is not None: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(): eval_rollout = model_based_env_eval.rollout( eval_rollout_steps, policy, From 725e6df1b0941c52708a3adf80317e3d8516ecd1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 15:43:58 +0100 Subject: [PATCH 13/24] amend --- sota-implementations/dreamer/dreamer_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 59a17ff8648..354aa466b7f 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -88,6 +88,7 @@ def _make_env(cfg, device, from_pixels=False): cfg.env.task, from_pixels=cfg.env.from_pixels or from_pixels, pixels_only=cfg.env.from_pixels, + device=device, ) else: raise NotImplementedError(f"Unknown lib {lib}.") @@ -98,7 +99,6 @@ def _make_env(cfg, device, from_pixels=False): env = env.append_transform( TensorDictPrimer(random=False, default_value=0, **default_dict) ) - assert env is not None return env From bd67b9f7e92581362480f92bc7da091d69211a96 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 15:51:45 +0100 Subject: [PATCH 14/24] amend --- sota-implementations/dreamer/config.yaml | 2 -- sota-implementations/dreamer/dreamer.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index ab101e8486a..6d719ed215c 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -18,8 +18,6 @@ collector: init_random_frames: 3000 frames_per_batch: 1000 device: - _target_: dreamer_utils._default_device - device: null optimization: train_every: 1000 diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e36d87deaa8..e402c4fde1a 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -15,7 +15,7 @@ make_collector, make_dreamer, make_environments, - make_replay_buffer, + make_replay_buffer, _default_device, ) from hydra.utils import instantiate @@ -38,7 +38,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # cfg = correct_for_frame_skip(cfg) - device = torch.device(instantiate(cfg.networks.device)) + device = _default_device(cfg.networks.device) # Create logger exp_name = generate_exp_name("Dreamer", cfg.logger.exp_name) From 0cf1c1aec26fae8bb14d8f95b0fe6524eb71e9b4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 15:53:28 +0100 Subject: [PATCH 15/24] amend --- sota-implementations/dreamer/config.yaml | 4 ---- sota-implementations/dreamer/dreamer_utils.py | 6 +++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 6d719ed215c..12e7c4e6446 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -10,8 +10,6 @@ env: horizon: 500 n_parallel_envs: 8 device: - _target_: dreamer_utils._default_device - device: null collector: total_frames: 5_000_000 @@ -39,8 +37,6 @@ optimization: networks: exploration_noise: 0.3 device: - _target_: dreamer_utils._default_device - device: null state_dim: 30 rssm_hidden_dim: 200 hidden_dim: 400 diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 354aa466b7f..520e846a583 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -129,7 +129,7 @@ def transform_env(cfg, env): def make_environments(cfg, parallel_envs=1, logger=None): """Make environments for training and evaluation.""" - func = functools.partial(_make_env, cfg=cfg, device=cfg.env.device) + func = functools.partial(_make_env, cfg=cfg, device=_default_device(cfg.env.device)) train_env = ParallelEnv( parallel_envs, EnvCreator(func), @@ -138,7 +138,7 @@ def make_environments(cfg, parallel_envs=1, logger=None): train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) func = functools.partial( - _make_env, cfg=cfg, device=cfg.env.device, from_pixels=cfg.logger.video + _make_env, cfg=cfg, device=_default_device(cfg.env.device), from_pixels=cfg.logger.video ) eval_env = ParallelEnv( 1, @@ -332,7 +332,7 @@ def make_collector(cfg, train_env, actor_model_explore): init_random_frames=cfg.collector.init_random_frames, frames_per_batch=cfg.collector.frames_per_batch, total_frames=cfg.collector.total_frames, - policy_device=instantiate(cfg.collector.device), + policy_device=_default_device(cfg.collector.device), env_device=train_env.device, storing_device="cpu", ) From 1aec5248df374fff4050db43b82a0ac016bc12b4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 16:14:23 +0100 Subject: [PATCH 16/24] amend --- torchrl/envs/libs/vmas.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/libs/vmas.py b/torchrl/envs/libs/vmas.py index 8e30fdb2a7e..9751e84a3ac 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -795,7 +795,9 @@ def _build_env( env=vmas.make_env( scenario=scenario, num_envs=num_envs, - device=self.device, + device=self.device + if self.device is not None + else torch.get_default_device(), continuous_actions=continuous_actions, max_steps=max_steps, seed=seed, From 0704caa75e7e2e4fd0d5c82da6017c4ad9c3e5c8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 16:16:41 +0100 Subject: [PATCH 17/24] amend --- sota-implementations/dreamer/dreamer.py | 4 ++-- sota-implementations/dreamer/dreamer_utils.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index e402c4fde1a..e521b9df386 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -10,14 +10,14 @@ import torch.cuda import tqdm from dreamer_utils import ( + _default_device, dump_video, log_metrics, make_collector, make_dreamer, make_environments, - make_replay_buffer, _default_device, + make_replay_buffer, ) -from hydra.utils import instantiate # mixed precision training from torch.cuda.amp import GradScaler diff --git a/sota-implementations/dreamer/dreamer_utils.py b/sota-implementations/dreamer/dreamer_utils.py index 520e846a583..73baa310821 100644 --- a/sota-implementations/dreamer/dreamer_utils.py +++ b/sota-implementations/dreamer/dreamer_utils.py @@ -9,7 +9,6 @@ import torch import torch.nn as nn -from hydra.utils import instantiate from tensordict import NestedKey from tensordict.nn import ( InteractionType, @@ -138,7 +137,10 @@ def make_environments(cfg, parallel_envs=1, logger=None): train_env = transform_env(cfg, train_env) train_env.set_seed(cfg.env.seed) func = functools.partial( - _make_env, cfg=cfg, device=_default_device(cfg.env.device), from_pixels=cfg.logger.video + _make_env, + cfg=cfg, + device=_default_device(cfg.env.device), + from_pixels=cfg.logger.video, ) eval_env = ParallelEnv( 1, From 1ea60c79d94182006091eb19af4a827c95e25f22 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 9 Jul 2024 08:40:16 -0700 Subject: [PATCH 18/24] amend --- torchrl/envs/batched_envs.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 7f462782757..2fb70484777 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -406,17 +406,18 @@ def _find_sync_values(self): return _do_nothing, _do_nothing if worker_device is None: - worker_not_main = [False] + worker_not_main = False - def find_all_worker_devices(item, worker_not_main=worker_not_main): + def find_all_worker_devices(item): + nonlocal worker_not_main if hasattr(item, "device"): - worker_not_main[0] = worker_not_main[0] or ( + worker_not_main = worker_not_main or ( item.device != self_device ) for td in self.shared_tensordicts: td.apply(find_all_worker_devices, filter_empty=True) - if worker_not_main[0]: + if worker_not_main: if torch.cuda.is_available(): worker_device = ( torch.device("cuda") @@ -431,6 +432,8 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): ) else: raise RuntimeError("Did not find a valid worker device") + else: + worker_device = self_device if ( worker_device is not None @@ -460,6 +463,7 @@ def find_all_worker_devices(item, worker_not_main=worker_not_main): and self_device.type == "mps" ): return _mps_sync(self_device), _mps_sync(self_device) + return _do_nothing, _do_nothing def __getstate__(self): out = copy(self.__dict__) From 52336a107554a91b0dd0da0e4ac0d945b3834432 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 9 Jul 2024 08:56:12 -0700 Subject: [PATCH 19/24] amend --- torchrl/envs/batched_envs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 2fb70484777..4df0546fc88 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -411,9 +411,7 @@ def _find_sync_values(self): def find_all_worker_devices(item): nonlocal worker_not_main if hasattr(item, "device"): - worker_not_main = worker_not_main or ( - item.device != self_device - ) + worker_not_main = worker_not_main or (item.device != self_device) for td in self.shared_tensordicts: td.apply(find_all_worker_devices, filter_empty=True) From de73d67a2a725c5ef96d7eb420a7812005be71df Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 18:29:25 +0100 Subject: [PATCH 20/24] amend --- torchrl/record/recorder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 2c2f3fb21ac..b7fb8ab4ed2 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -221,11 +221,11 @@ def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: observation_trsf = make_grid( obs_flat, nrow=int(math.ceil(math.sqrt(obs_flat.shape[0]))) ) - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) elif observation_trsf.ndimension() >= 4: - self.obs.extend(observation_trsf.to(torch.uint8).flatten(0, -4)) + self.obs.extend(observation_trsf.to("cpu", torch.uint8).flatten(0, -4)) else: - self.obs.append(observation_trsf.to(torch.uint8)) + self.obs.append(observation_trsf.to("cpu", torch.uint8)) return observation def forward(self, tensordict: TensorDictBase) -> TensorDictBase: From 9e481a65da56d57b90678205c6f0e359a3f546f2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 20:50:24 +0100 Subject: [PATCH 21/24] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 075489b208d..c65e6124efe 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -24,6 +24,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export CUDA_LAUNCH_BLOCKING=1 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test.py -v --durations 200 #python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/smoke_test_deps.py -v --durations 200 From d9a71780d475fd84c1d4afb06d8b3a2ac0d8d064 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Jul 2024 20:53:12 +0100 Subject: [PATCH 22/24] amend --- .../linux_examples/scripts/run_test.sh | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index c65e6124efe..af1d4952479 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -9,9 +9,19 @@ # # -set -e +#set -e set -v +# Initialize an error flag +error_occurred=0 +# Function to handle errors +error_handler() { + echo "Error on line $1" + error_occurred=1 +} +# Trap ERR to call the error_handler function with the failing line number +trap 'error_handler $LINENO' ERR + export PYTORCH_TEST_WITH_SLOW='1' python -m torch.utils.collect_env # Avoid error: "fatal: unsafe repository" @@ -299,3 +309,11 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/ba coverage combine coverage xml -i + +# Check if any errors occurred during the script execution +if [ "$error_occurred" -ne 0 ]; then + echo "Errors occurred during script execution" + exit 1 +else + echo "Script executed successfully" +fi From 4e0596c41c7686b62bd55ef7390ca7beeceb970d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Jul 2024 11:52:49 +0100 Subject: [PATCH 23/24] amend --- sota-implementations/dreamer/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/dreamer/config.yaml b/sota-implementations/dreamer/config.yaml index 12e7c4e6446..604e1ac546a 100644 --- a/sota-implementations/dreamer/config.yaml +++ b/sota-implementations/dreamer/config.yaml @@ -9,7 +9,7 @@ env: image_size : 64 horizon: 500 n_parallel_envs: 8 - device: + device: cpu collector: total_frames: 5_000_000 From 4315f0c19add8660bec0949810ecb6336839bda1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Jul 2024 12:06:41 +0100 Subject: [PATCH 24/24] amend --- .../unittest/linux_examples/scripts/run_test.sh | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 039ca3059cd..f8b700c0410 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -174,18 +174,6 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cr env.name=Pendulum-v1 \ network.device= \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dreamer/dreamer.py \ - collector.total_frames=200 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=200 \ - env.n_parallel_envs=4 \ - optimization.optim_steps_per_batch=1 \ - logger.video=True \ - logger.backend=csv \ - replay_buffer.buffer_size=120 \ - replay_buffer.batch_size=24 \ - replay_buffer.batch_length=12 \ - networks.rssm_hidden_dim=17 python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3/td3.py \ collector.total_frames=48 \ collector.init_random_frames=10 \ @@ -225,8 +213,8 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/dr collector.frames_per_batch=200 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ - logger.backend=csv \ logger.video=True \ + logger.backend=csv \ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ replay_buffer.batch_length=12 \