Commit

Update

[ghstack-poisoned]
vmoens committed Jan 15, 2025
2 parents e94a40e + 0b0993e commit ae13f64
Showing 31 changed files with 765 additions and 224 deletions.
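For context: the diff below replaces the per-instance `aggregate_probabilities` flag with tensordict's global `set_composite_lp_aggregate` switch, and renames the default log-prob key from `sample_log_prob` to `action_log_prob`. A minimal sketch of how the switch behaves, following the class-scoped fixture pattern added in the hunks below (a sketch of the tensordict API as exercised here, not an authoritative reference):

```python
from tensordict.nn import composite_lp_aggregate, set_composite_lp_aggregate

# The class-scoped pytest fixtures added below flip the global mode for a
# whole test class, then restore it on teardown.
setter = set_composite_lp_aggregate(False)
setter.set()
assert not composite_lp_aggregate()  # per-leaf log-probs, no aggregation
setter.unset()  # restores the previous mode
```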
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -829,6 +829,7 @@ to be able to create this other composition:
GrayScale
InitTracker
KLRewardTransform
LineariseRewards
NoopResetEnv
ObservationNorm
ObservationTransform
83 changes: 52 additions & 31 deletions test/test_cost.py
@@ -8,7 +8,6 @@
import itertools
import operator
import os

import sys
import warnings
from copy import deepcopy
@@ -23,13 +22,15 @@
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict._C import unravel_keys
from tensordict.nn import (
composite_lp_aggregate,
CompositeDistribution,
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule,
ProbabilisticTensorDictModule as ProbMod,
ProbabilisticTensorDictSequential,
ProbabilisticTensorDictSequential as ProbSeq,
set_composite_lp_aggregate,
TensorDictModule,
TensorDictModule as Mod,
TensorDictSequential,
@@ -3540,6 +3541,13 @@ def test_td3bc_reduction(self, reduction):
class TestSAC(LossModuleTestBase):
seed = 0

@pytest.fixture(scope="class", autouse=True)
def _composite_log_prob(self):
setter = set_composite_lp_aggregate(False)
setter.set()
yield
setter.unset()

def _create_mock_actor(
self,
batch=2,
@@ -3563,7 +3571,6 @@ def _create_mock_actor(
distribution_map={
"action1": TanhNormal,
},
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -3583,6 +3590,7 @@
out_keys=[action_key],
spec=action_spec,
)
assert actor.log_prob_keys
return actor.to(device)

def _create_mock_qvalue(
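A standalone sketch of the key naming that the `assert actor.log_prob_keys` line above relies on, using a plain `torch.distributions.Normal` in place of the test's `TanhNormal` (the exact contents of `log_prob_keys` are an assumption based on the key defaults visible in this diff):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule,
    set_composite_lp_aggregate,
    set_interaction_type,
)
from torch import distributions as d

module = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=d.Normal,
    return_log_prob=True,
)
with set_composite_lp_aggregate(False), set_interaction_type(InteractionType.RANDOM):
    # with aggregation disabled, the log-prob key derives from the sample
    # key ("action" -> "action_log_prob") rather than the legacy
    # "sample_log_prob"
    assert "action_log_prob" in module.log_prob_keys
    td = module(TensorDict({"loc": torch.zeros(3), "scale": torch.ones(3)}, []))
    assert "action_log_prob" in td.keys()
```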
@@ -3688,7 +3696,6 @@ def forward(self, obs, act):
distribution_map={
"action1": TanhNormal,
},
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -4342,7 +4349,7 @@ def test_sac_tensordict_keys(self, td_est, version, composite_action_dist):
"value": "state_value",
"state_action_value": "state_action_value",
"action": "action",
"log_prob": "sample_log_prob",
"log_prob": "action_log_prob",
"reward": "reward",
"done": "done",
"terminated": "terminated",
@@ -4616,6 +4623,13 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
class TestDiscreteSAC(LossModuleTestBase):
seed = 0

@pytest.fixture(scope="class", autouse=True)
def _composite_log_prob(self):
setter = set_composite_lp_aggregate(False)
setter.set()
yield
setter.unset()

def _create_mock_actor(
self,
batch=2,
@@ -7902,6 +7916,13 @@ def test_dcql_reduction(self, reduction):
class TestPPO(LossModuleTestBase):
seed = 0

@pytest.fixture(scope="class", autouse=True)
def _composite_log_prob(self):
setter = set_composite_lp_aggregate(False)
setter.set()
yield
setter.unset()

def _create_mock_actor(
self,
batch=2,
Expand All @@ -7910,9 +7931,8 @@ def _create_mock_actor(
device="cpu",
action_key=None,
observation_key="observation",
sample_log_prob_key="sample_log_prob",
sample_log_prob_key=None,
composite_action_dist=False,
aggregate_probabilities=None,
):
# Actor
action_spec = Bounded(
@@ -7934,7 +7954,6 @@
"action1": action_key,
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=aggregate_probabilities,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -8006,7 +8025,6 @@ def _create_mock_actor_value(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -8063,7 +8081,6 @@ def _create_mock_actor_value_shared(
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -8182,7 +8199,8 @@ def _create_seq_mock_data_ppo(
if composite_action_dist:
sample_log_prob_key = ("action", "action1_log_prob")
else:
sample_log_prob_key = "sample_log_prob"
# conforming to composite_lp_aggregate(False)
sample_log_prob_key = "action_log_prob"

if action_key is None:
if composite_action_dist:
@@ -8285,6 +8303,7 @@ def test_ppo(
if advantage is not None:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
if advantage is not None:
assert not composite_lp_aggregate()
advantage(td)
else:
if td_est is not None:
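The `advantage.set_keys(...)` call above is the generic escape hatch when per-leaf log-prob keys are in play. A short sketch with torchrl's GAE (`value` stands in for the mock value network built by these tests):

```python
from torchrl.objectives.value import GAE

advantage = GAE(gamma=0.98, lmbda=0.95, value_network=value)
# composite dist + no aggregation: each leaf carries its own log-prob
# entry, so point the estimator at the nested per-leaf key
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
```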
@@ -8344,7 +8363,6 @@ def test_ppo_composite_no_aggregate(
actor = self._create_mock_actor(
device=device,
composite_action_dist=True,
aggregate_probabilities=False,
)
value = self._create_mock_value(device=device)
if advantage == "gae":
@@ -8764,6 +8782,7 @@ def zero_param(p):
)
@pytest.mark.parametrize("composite_action_dist", [True, False])
def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
assert not composite_lp_aggregate()
actor = self._create_mock_actor(composite_action_dist=composite_action_dist)
value = self._create_mock_value()

@@ -8773,8 +8792,10 @@ def test_ppo_tensordict_keys(self, loss_class, td_est, composite_action_dist):
"advantage": "advantage",
"value_target": "value_target",
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"action": "action",
"sample_log_prob": "action_log_prob"
if not composite_action_dist
else ("action", "action1_log_prob"),
"action": "action" if not composite_action_dist else ("action", "action1"),
"reward": "reward",
"done": "done",
"terminated": "terminated",
@@ -9160,9 +9181,6 @@ def mixture_constructor(logits, loc, scale):
"Kumaraswamy": ("agent1", "action"),
"mixture": ("agent2", "action"),
},
aggregate_probabilities=False,
include_sum=False,
inplace=True,
)
policy = ProbSeq(
make_params,
@@ -9181,15 +9199,11 @@
# We want to make sure there is no warning
td = policy(TensorDict(batch_size=[4]))
assert isinstance(
policy.get_dist(td).log_prob(
td, aggregate_probabilities=False, inplace=False, include_sum=False
),
policy.get_dist(td).log_prob(td),
TensorDict,
)
assert isinstance(
policy.log_prob(
td, aggregate_probabilities=False, inplace=False, include_sum=False
),
policy.log_prob(td),
TensorDict,
)
value_operator = Seq(
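The two `isinstance` assertions above capture the core behavioral change: `log_prob` no longer takes `aggregate_probabilities`/`include_sum`/`inplace` arguments, and its return type follows the global mode. A self-contained sketch using stock `torch.distributions` (the test's TanhNormal/Kumaraswamy/mixture are swapped for simpler distributions; a sketch, not the test's exact setup):

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
from torch import distributions as d

params = TensorDict(
    {
        "cont": {"loc": torch.randn(3, 4), "scale": torch.rand(3, 4)},
        ("nested", "disc"): {"logits": torch.randn(3, 10)},
    },
    batch_size=[3],
)
dist = CompositeDistribution(
    params,
    distribution_map={"cont": d.Normal, ("nested", "disc"): d.Categorical},
)
sample = dist.sample()

with set_composite_lp_aggregate(False):
    lp = dist.log_prob(sample)
    assert isinstance(lp, TensorDict)  # one "*_log_prob" leaf per distribution

with set_composite_lp_aggregate(True):
    lp = dist.log_prob(sample)
    assert isinstance(lp, torch.Tensor)  # leaves summed into a single tensor
```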
@@ -9226,6 +9240,13 @@
class TestA2C(LossModuleTestBase):
seed = 0

@pytest.fixture(scope="class", autouse=True)
def _composite_log_prob(self):
setter = set_composite_lp_aggregate(False)
setter.set()
yield
setter.unset()

def _create_mock_actor(
self,
batch=2,
Expand All @@ -9234,8 +9255,8 @@ def _create_mock_actor(
device="cpu",
action_key="action",
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
sample_log_prob_key=None,
):
# Actor
action_spec = Bounded(
@@ -9253,8 +9274,6 @@
name_map={
"action1": (action_key, "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -9304,7 +9323,6 @@ def _create_mock_common_layer_setup(
n_hidden=2,
T=10,
composite_action_dist=False,
sample_log_prob_key="sample_log_prob",
):
common_net = MLP(
num_cells=ncells,
@@ -9330,7 +9348,7 @@
{
"obs": torch.randn(*batch, n_obs),
"action": {"action1": action} if composite_action_dist else action,
"sample_log_prob": torch.randn(*batch),
"action_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
"next": {
@@ -9354,8 +9372,6 @@
name_map={
"action1": ("action", "action1"),
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=True,
)
module_out_keys = [
("params", "action1", "loc"),
@@ -9396,7 +9412,7 @@ def _create_seq_mock_data_a2c(
reward_key="reward",
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
sample_log_prob_key="action_log_prob",
composite_action_dist=False,
):
# create a tensordict
@@ -9528,6 +9544,11 @@ def set_requires_grad(tensor, requires_grad):

td = td.exclude(loss_fn.tensor_keys.value_target)
if advantage is not None:
advantage.set_keys(
sample_log_prob=actor.log_prob_keys
if composite_action_dist
else "action_log_prob"
)
advantage(td)
elif td_est is not None:
loss_fn.make_value_estimator(td_est)
@@ -9747,7 +9768,7 @@ def test_a2c_tensordict_keys(self, td_est, composite_action_dist):
"reward": "reward",
"done": "done",
"terminated": "terminated",
"sample_log_prob": "sample_log_prob",
"sample_log_prob": "action_log_prob",
}

self.tensordict_keys_test(