diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index b7af381634c..25d1e7a4390 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -188,19 +188,6 @@ ppo.collector.frames_per_batch=16 \ logger.mode=offline \ logger.backend= -""", - "dreamer": """python sota-implementations/dreamer/dreamer.py \ - collector.total_frames=600 \ - collector.init_random_frames=10 \ - collector.frames_per_batch=200 \ - env.n_parallel_envs=1 \ - optimization.optim_steps_per_batch=1 \ - logger.video=False \ - logger.backend=csv \ - replay_buffer.buffer_size=120 \ - replay_buffer.batch_size=24 \ - replay_buffer.batch_length=12 \ - networks.rssm_hidden_dim=17 """, "ddpg-single": """python sota-implementations/ddpg/ddpg.py \ collector.total_frames=48 \ @@ -289,6 +276,19 @@ logger.backend= """, "bandits": """python sota-implementations/bandits/dqn.py --n_steps=100 +""", + "dreamer": """python sota-implementations/dreamer/dreamer.py \ + collector.total_frames=600 \ + collector.init_random_frames=10 \ + collector.frames_per_batch=200 \ + env.n_parallel_envs=1 \ + optimization.optim_steps_per_batch=1 \ + logger.video=False \ + logger.backend=csv \ + replay_buffer.buffer_size=120 \ + replay_buffer.batch_size=24 \ + replay_buffer.batch_length=12 \ + networks.rssm_hidden_dim=17 """, } diff --git a/test/test_cost.py b/test/test_cost.py index 61ff0517024..8bbb7edce05 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -18,7 +18,6 @@ import torch from packaging import version, version as pack_version - from tensordict import assert_allclose_td, TensorDict, TensorDictBase from tensordict._C import unravel_keys from tensordict.nn import ( @@ -37,6 +36,7 @@ TensorDictSequential as Seq, WrapModule, ) +from tensordict.nn.distributions.composite import _add_suffix from tensordict.nn.utils import Buffer from tensordict.utils import unravel_key from torch import autograd, nn @@ -199,6 +199,13 @@ def get_devices(): class LossModuleTestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) assert hasattr( @@ -3541,13 +3548,6 @@ def test_td3bc_reduction(self, reduction): class TestSAC(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -4623,13 +4623,6 @@ def test_sac_reduction(self, reduction, version, composite_action_dist): class TestDiscreteSAC(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -6786,7 +6779,7 @@ def test_redq_tensordict_keys(self, td_est): "priority": "td_error", "action": "action", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "state_action_value": "state_action_value", "reward": "reward", "done": "done", @@ -6849,12 +6842,22 @@ def test_redq_notensordict( actor_network=actor, qvalue_network=qvalue, ) - loss.set_keys( - action=action_key, - reward=reward_key, - done=done_key, - terminated=terminated_key, - ) + if deprec: + loss.set_keys( + action=action_key, + 
reward=reward_key, + done=done_key, + terminated=terminated_key, + log_prob=_add_suffix(action_key, "_log_prob"), + ) + else: + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + sample_log_prob=_add_suffix(action_key, "_log_prob"), + ) kwargs = { action_key: td.get(action_key), @@ -7916,13 +7919,6 @@ def test_dcql_reduction(self, reduction): class TestPPO(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -8003,7 +7999,7 @@ def _create_mock_actor_value( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8058,7 +8054,7 @@ def _create_mock_actor_value_shared( action_dim=4, device="cpu", composite_action_dist=False, - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", ): # Actor action_spec = Bounded( @@ -8123,7 +8119,7 @@ def _create_mock_data_ppo( reward_key="reward", done_key="done", terminated_key="terminated", - sample_log_prob_key="sample_log_prob", + sample_log_prob_key="action_log_prob", composite_action_dist=False, ): # create a tensordict @@ -8834,7 +8830,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est): "advantage": "advantage_test", "value_target": "value_target_test", "value": "state_value_test", - "sample_log_prob": "sample_log_prob_test", + "sample_log_prob": "action_log_prob_test", "action": "action_test", } @@ -9242,13 +9238,6 @@ def mixture_constructor(logits, loc, scale): class TestA2C(LossModuleTestBase): seed = 0 - @pytest.fixture(scope="class", autouse=True) - def _composite_log_prob(self): - setter = set_composite_lp_aggregate(False) - setter.set() - yield - setter.unset() - def _create_mock_actor( self, batch=2, @@ -9814,7 +9803,7 @@ def test_a2c_tensordict_keys_run( value_key = "state_value_test" action_key = "action_test" reward_key = "reward_test" - sample_log_prob_key = "sample_log_prob_test" + sample_log_prob_key = "action_log_prob_test" done_key = ("done", "test") terminated_key = ("terminated", "test") @@ -10258,7 +10247,7 @@ def test_reinforce_tensordict_keys(self, td_est): "advantage": "advantage", "value_target": "value_target", "value": "state_value", - "sample_log_prob": "sample_log_prob", + "sample_log_prob": "action_log_prob", "reward": "reward", "done": "done", "terminated": "terminated", @@ -10316,7 +10305,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -11788,7 +11777,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -12604,7 +12593,7 @@ def _create_mock_common_layer_setup( { "obs": torch.randn(*batch, n_obs), "action": torch.randn(*batch, n_act), - "sample_log_prob": torch.randn(*batch), + "action_log_prob": torch.randn(*batch), "done": torch.zeros(*batch, 1, dtype=torch.bool), "terminated": 
torch.zeros(*batch, 1, dtype=torch.bool), "next": { @@ -15228,6 +15217,7 @@ def test_successive_traj_gae( ["half", torch.half, "cpu"], ], ) +@set_composite_lp_aggregate(False) def test_shared_params(dest, expected_dtype, expected_device): if torch.cuda.device_count() == 0 and dest == "cuda": pytest.skip("no cuda device available") @@ -15332,6 +15322,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: class TestAdv: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + @pytest.mark.parametrize( "adv,kwargs", [ @@ -15369,7 +15366,7 @@ def test_dispatch( ) kwargs = { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next_reward": torch.randn(1, 10, 1, requires_grad=True), "next_done": torch.zeros(1, 10, 1, dtype=torch.bool), "next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool), @@ -15431,7 +15428,7 @@ def test_diff_reward( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15504,7 +15501,7 @@ def test_non_differentiable(self, adv, shifted, kwargs): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15575,7 +15572,7 @@ def test_time_dim(self, adv, kwargs, shifted=True): td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "next": { "obs": torch.randn(1, 10, 3), "reward": torch.randn(1, 10, 1, requires_grad=True), @@ -15676,7 +15673,7 @@ def test_skip_existing( td = TensorDict( { "obs": torch.randn(1, 10, 3), - "sample_log_prob": torch.log(torch.rand(1, 10, 1)), + "action_log_prob": torch.log(torch.rand(1, 10, 1)), "state_value": torch.ones(1, 10, 1), "next": { "obs": torch.randn(1, 10, 3), @@ -15814,6 +15811,13 @@ def test_set_deprecated_keys(self, adv, kwargs): class TestBase: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + def test_decorators(self): class MyLoss(LossModule): def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -16033,6 +16037,13 @@ class _AcceptedKeys: class TestUtils: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + @pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip @pytest.mark.parametrize("T", [1, 10]) @pytest.mark.parametrize("device", get_default_devices()) @@ -16203,6 +16214,7 @@ def fun(a, b, time_dim=-2): (SoftUpdate, {"eps": 0.99}), ], ) +@set_composite_lp_aggregate(False) def test_updater_warning(updater, kwarg): with warnings.catch_warnings(): dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot") @@ -16215,6 +16227,13 @@ def test_updater_warning(updater, kwarg): class TestSingleCall: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + 
setter.unset() + def _mock_value_net(self, has_target, value_key): model = nn.Linear(3, 1) module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key]) @@ -16267,6 +16286,7 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) assert (value != value_).all() +@set_composite_lp_aggregate(False) def test_instantiate_with_different_keys(): loss_1 = DQNLoss( value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True @@ -16281,6 +16301,13 @@ def test_instantiate_with_different_keys(): class TestBuffer: + @pytest.fixture(scope="class", autouse=True) + def _composite_log_prob(self): + setter = set_composite_lp_aggregate(False) + setter.set() + yield + setter.unset() + # @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half)) # def test_param_cast(self, dtype): # param = nn.Parameter(torch.zeros(3)) @@ -16390,6 +16417,7 @@ def __init__(self): TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5" ) @pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile") +@set_composite_lp_aggregate(False) def test_exploration_compile(): try: torch._dynamo.reset_code_caches() @@ -16456,6 +16484,7 @@ def func(t): assert it == exploration_type() +@set_composite_lp_aggregate(False) def test_loss_exploration(): class DummyLoss(LossModule): def forward(self, td, mode): diff --git a/torchrl/objectives/deprecated.py b/torchrl/objectives/deprecated.py index 28d664957c5..57cea971bc0 100644 --- a/torchrl/objectives/deprecated.py +++ b/torchrl/objectives/deprecated.py @@ -13,7 +13,7 @@ import torch from tensordict import TensorDict, TensorDictBase, TensorDictParams -from tensordict.nn import dispatch, TensorDictModule +from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor @@ -121,12 +121,20 @@ class _AcceptedKeys: action: NestedKey = "action" state_action_value: NestedKey = "state_action_value" value: NestedKey = "state_value" - log_prob: NestedKey = "_log_prob" + log_prob: NestedKey | None = None priority: NestedKey = "td_error" reward: NestedKey = "reward" done: NestedKey = "done" terminated: NestedKey = "terminated" + def __post_init__(self): + if self.log_prob is None: + if composite_lp_aggregate(nowarn=True): + self.log_prob = "sample_log_prob" + else: + self.log_prob = "action_log_prob" + + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys delay_actor: bool = False default_value_estimator = ValueEstimators.TD0 @@ -358,12 +366,14 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]: tensordict_clone.select(*self.qvalue_network.in_keys, strict=False), self._cached_detach_qvalue_network_params, ) - state_action_value = tensordict_expand.get("state_action_value").squeeze(-1) + state_action_value = tensordict_expand.get( + self.tensor_keys.state_action_value + ).squeeze(-1) loss_actor = -( state_action_value - - self.alpha * tensordict_clone.get("sample_log_prob").squeeze(-1) + - self.alpha * tensordict_clone.get(self.tensor_keys.log_prob).squeeze(-1) ) - return loss_actor, tensordict_clone.get("sample_log_prob") + return loss_actor, tensordict_clone.get(self.tensor_keys.log_prob) def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: tensordict_save = tensordict @@ -388,30 +398,33 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor: ExplorationType.RANDOM ), self.target_actor_network_params.to_module(self.actor_network): self.actor_network(next_td) - sample_log_prob = 
next_td.get("sample_log_prob") + sample_log_prob = next_td.get(self.tensor_keys.log_prob) # get q-values next_td = self._vmap_qvalue_networkN0( next_td, selected_q_params, ) - state_action_value = next_td.get("state_action_value") + state_action_value = next_td.get(self.tensor_keys.state_action_value) if ( state_action_value.shape[-len(sample_log_prob.shape) :] != sample_log_prob.shape ): sample_log_prob = sample_log_prob.unsqueeze(-1) next_state_value = ( - next_td.get("state_action_value") - self.alpha * sample_log_prob + next_td.get(self.tensor_keys.state_action_value) + - self.alpha * sample_log_prob ) next_state_value = next_state_value.min(0)[0] - tensordict.set(("next", "state_value"), next_state_value) + tensordict.set(("next", self.tensor_keys.value), next_state_value) target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) tensordict_expand = self._vmap_qvalue_networkN0( tensordict.select(*self.qvalue_network.in_keys, strict=False), self.qvalue_network_params, ) - pred_val = tensordict_expand.get("state_action_value").squeeze(-1) + pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze( + -1 + ) td_error = abs(pred_val - target_value) loss_qval = distance_loss( pred_val, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 361fba322a3..b5b45393a06 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -317,7 +317,6 @@ def __post_init__(self): target_actor_network_params: TensorDictParams target_critic_network_params: TensorDictParams - @set_composite_lp_aggregate(False) def __init__( self, actor_network: ProbabilisticTensorDictSequential | None = None, @@ -487,7 +486,6 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: def reset(self) -> None: pass - @set_composite_lp_aggregate(False) def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: entropy = dist.entropy() @@ -496,20 +494,21 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: x = dist.rsample((self.samples_mc_entropy,)) else: x = dist.sample((self.samples_mc_entropy,)) - log_prob = dist.log_prob(x) - - if is_tensor_collection(log_prob): - if isinstance(self.tensor_keys.sample_log_prob, NestedKey): - log_prob = log_prob.get(self.tensor_keys.sample_log_prob) - else: - log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) + with set_composite_lp_aggregate(False) if isinstance( + dist, CompositeDistribution + ) else contextlib.nullcontext(): + log_prob = dist.log_prob(x) + if is_tensor_collection(log_prob): + if isinstance(self.tensor_keys.sample_log_prob, NestedKey): + log_prob = log_prob.get(self.tensor_keys.sample_log_prob) + else: + log_prob = log_prob.select(*self.tensor_keys.sample_log_prob) entropy = -log_prob.mean(0) if is_tensor_collection(entropy): entropy = _sum_td_features(entropy) return entropy.unsqueeze(-1) - @set_composite_lp_aggregate(False) def _log_weight( self, tensordict: TensorDictBase ) -> Tuple[torch.Tensor, d.Distribution]: @@ -550,21 +549,22 @@ def _log_weight( ) log_prob = dist.log_prob(action) if is_composite: - if not is_tensor_collection(prev_log_prob): - # this isn't great, in general multihead actions should have a composite log-prob too - warnings.warn( - "You are using a composite distribution, yet your log-probability is a tensor. 
" - "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at " - "the beginning of your script to get a proper composite log-prob.", - category=UserWarning, - ) - if ( - is_composite - and not is_tensor_collection(prev_log_prob) - and is_tensor_collection(log_prob) - ): - log_prob = _sum_td_features(log_prob) - log_prob.view_as(prev_log_prob) + with set_composite_lp_aggregate(False): + if not is_tensor_collection(prev_log_prob): + # this isn't great, in general multihead actions should have a composite log-prob too + warnings.warn( + "You are using a composite distribution, yet your log-probability is a tensor. " + "Make sure you have called tensordict.nn.set_composite_lp_aggregate(False).set() at " + "the beginning of your script to get a proper composite log-prob.", + category=UserWarning, + ) + if ( + is_composite + and not is_tensor_collection(prev_log_prob) + and is_tensor_collection(log_prob) + ): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -1215,11 +1215,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: self.actor_network ) if self.functional else contextlib.nullcontext(): current_dist = self.actor_network.get_dist(tensordict_copy) + is_composite = isinstance(current_dist, CompositeDistribution) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - with set_composite_lp_aggregate(False): + with set_composite_lp_aggregate( + False + ) if is_composite else contextlib.nullcontext(): previous_log_prob = previous_dist.log_prob(x) current_log_prob = current_dist.log_prob(x) if is_tensor_collection(previous_log_prob): diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index cff4a016105..bec40681f92 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -240,6 +240,7 @@ def __post_init__(self): else: self.sample_log_prob = "action_log_prob" + tensor_keys: _AcceptedKeys default_keys = _AcceptedKeys delay_actor: bool = False default_value_estimator = ValueEstimators.TD0