Commit

Update (base update)

[ghstack-poisoned]

vmoens committed Jan 15, 2025
1 parent 866ff1a commit d5d49da
Showing 5 changed files with 145 additions and 99 deletions.
26 changes: 13 additions & 13 deletions .github/unittest/linux_sota/scripts/test_sota.py
@@ -188,19 +188,6 @@
ppo.collector.frames_per_batch=16 \
logger.mode=offline \
logger.backend=
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
""",
"ddpg-single": """python sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
@@ -289,6 +276,19 @@
logger.backend=
""",
"bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
""",
}

129 changes: 79 additions & 50 deletions test/test_cost.py
@@ -18,7 +18,6 @@
import torch

from packaging import version, version as pack_version
-
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
from tensordict._C import unravel_keys
from tensordict.nn import (
@@ -37,6 +36,7 @@
TensorDictSequential as Seq,
WrapModule,
)
+from tensordict.nn.distributions.composite import _add_suffix
from tensordict.nn.utils import Buffer
from tensordict.utils import unravel_key
from torch import autograd, nn
@@ -199,6 +199,13 @@ def get_devices():


class LossModuleTestBase:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
assert hasattr(
@@ -3541,13 +3548,6 @@ def test_td3bc_reduction(self, reduction):
class TestSAC(LossModuleTestBase):
seed = 0

-@pytest.fixture(scope="class", autouse=True)
-def _composite_log_prob(self):
-setter = set_composite_lp_aggregate(False)
-setter.set()
-yield
-setter.unset()
-
def _create_mock_actor(
self,
batch=2,
@@ -4623,13 +4623,6 @@ def test_sac_reduction(self, reduction, version, composite_action_dist):
class TestDiscreteSAC(LossModuleTestBase):
seed = 0

-@pytest.fixture(scope="class", autouse=True)
-def _composite_log_prob(self):
-setter = set_composite_lp_aggregate(False)
-setter.set()
-yield
-setter.unset()
-
def _create_mock_actor(
self,
batch=2,
@@ -6786,7 +6779,7 @@ def test_redq_tensordict_keys(self, td_est):
"priority": "td_error",
"action": "action",
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"sample_log_prob": "action_log_prob",
"state_action_value": "state_action_value",
"reward": "reward",
"done": "done",
@@ -6849,12 +6842,22 @@ def test_redq_notensordict(
actor_network=actor,
qvalue_network=qvalue,
)
-loss.set_keys(
-action=action_key,
-reward=reward_key,
-done=done_key,
-terminated=terminated_key,
-)
+if deprec:
+loss.set_keys(
+action=action_key,
+reward=reward_key,
+done=done_key,
+terminated=terminated_key,
+log_prob=_add_suffix(action_key, "_log_prob"),
+)
+else:
+loss.set_keys(
+action=action_key,
+reward=reward_key,
+done=done_key,
+terminated=terminated_key,
+sample_log_prob=_add_suffix(action_key, "_log_prob"),
+)

kwargs = {
action_key: td.get(action_key),
@@ -7916,13 +7919,6 @@ def test_dcql_reduction(self, reduction):
class TestPPO(LossModuleTestBase):
seed = 0

-@pytest.fixture(scope="class", autouse=True)
-def _composite_log_prob(self):
-setter = set_composite_lp_aggregate(False)
-setter.set()
-yield
-setter.unset()
-
def _create_mock_actor(
self,
batch=2,
@@ -8003,7 +7999,7 @@ def _create_mock_actor_value(
action_dim=4,
device="cpu",
composite_action_dist=False,
sample_log_prob_key="sample_log_prob",
sample_log_prob_key="action_log_prob",
):
# Actor
action_spec = Bounded(
@@ -8058,7 +8054,7 @@ def _create_mock_actor_value_shared(
action_dim=4,
device="cpu",
composite_action_dist=False,
sample_log_prob_key="sample_log_prob",
sample_log_prob_key="action_log_prob",
):
# Actor
action_spec = Bounded(
@@ -8123,7 +8119,7 @@ def _create_mock_data_ppo(
reward_key="reward",
done_key="done",
terminated_key="terminated",
sample_log_prob_key="sample_log_prob",
sample_log_prob_key="action_log_prob",
composite_action_dist=False,
):
# create a tensordict
@@ -8834,7 +8830,7 @@ def test_ppo_tensordict_keys_run(self, loss_class, advantage, td_est):
"advantage": "advantage_test",
"value_target": "value_target_test",
"value": "state_value_test",
"sample_log_prob": "sample_log_prob_test",
"sample_log_prob": "action_log_prob_test",
"action": "action_test",
}

@@ -9242,13 +9238,6 @@ def mixture_constructor(logits, loc, scale):
class TestA2C(LossModuleTestBase):
seed = 0

-@pytest.fixture(scope="class", autouse=True)
-def _composite_log_prob(self):
-setter = set_composite_lp_aggregate(False)
-setter.set()
-yield
-setter.unset()
-
def _create_mock_actor(
self,
batch=2,
@@ -9814,7 +9803,7 @@ def test_a2c_tensordict_keys_run(
value_key = "state_value_test"
action_key = "action_test"
reward_key = "reward_test"
-sample_log_prob_key = "sample_log_prob_test"
+sample_log_prob_key = "action_log_prob_test"
done_key = ("done", "test")
terminated_key = ("terminated", "test")

@@ -10258,7 +10247,7 @@ def test_reinforce_tensordict_keys(self, td_est):
"advantage": "advantage",
"value_target": "value_target",
"value": "state_value",
"sample_log_prob": "sample_log_prob",
"sample_log_prob": "action_log_prob",
"reward": "reward",
"done": "done",
"terminated": "terminated",
@@ -10316,7 +10305,7 @@ def _create_mock_common_layer_setup(
{
"obs": torch.randn(*batch, n_obs),
"action": torch.randn(*batch, n_act),
"sample_log_prob": torch.randn(*batch),
"action_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
"next": {
@@ -11788,7 +11777,7 @@ def _create_mock_common_layer_setup(
{
"obs": torch.randn(*batch, n_obs),
"action": torch.randn(*batch, n_act),
"sample_log_prob": torch.randn(*batch),
"action_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
"next": {
@@ -12604,7 +12593,7 @@ def _create_mock_common_layer_setup(
{
"obs": torch.randn(*batch, n_obs),
"action": torch.randn(*batch, n_act),
"sample_log_prob": torch.randn(*batch),
"action_log_prob": torch.randn(*batch),
"done": torch.zeros(*batch, 1, dtype=torch.bool),
"terminated": torch.zeros(*batch, 1, dtype=torch.bool),
"next": {
@@ -15228,6 +15217,7 @@ def test_successive_traj_gae(
["half", torch.half, "cpu"],
],
)
+@set_composite_lp_aggregate(False)
def test_shared_params(dest, expected_dtype, expected_device):
if torch.cuda.device_count() == 0 and dest == "cuda":
pytest.skip("no cuda device available")
@@ -15332,6 +15322,13 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:


class TestAdv:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
@pytest.mark.parametrize(
"adv,kwargs",
[
@@ -15369,7 +15366,7 @@ def test_dispatch(
)
kwargs = {
"obs": torch.randn(1, 10, 3),
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
"next_reward": torch.randn(1, 10, 1, requires_grad=True),
"next_done": torch.zeros(1, 10, 1, dtype=torch.bool),
"next_terminated": torch.zeros(1, 10, 1, dtype=torch.bool),
@@ -15431,7 +15428,7 @@ def test_diff_reward(
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
"next": {
"obs": torch.randn(1, 10, 3),
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15504,7 +15501,7 @@ def test_non_differentiable(self, adv, shifted, kwargs):
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
"next": {
"obs": torch.randn(1, 10, 3),
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15575,7 +15572,7 @@ def test_time_dim(self, adv, kwargs, shifted=True):
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
"next": {
"obs": torch.randn(1, 10, 3),
"reward": torch.randn(1, 10, 1, requires_grad=True),
@@ -15676,7 +15673,7 @@ def test_skip_existing(
td = TensorDict(
{
"obs": torch.randn(1, 10, 3),
"sample_log_prob": torch.log(torch.rand(1, 10, 1)),
"action_log_prob": torch.log(torch.rand(1, 10, 1)),
"state_value": torch.ones(1, 10, 1),
"next": {
"obs": torch.randn(1, 10, 3),
@@ -15814,6 +15811,13 @@ def test_set_deprecated_keys(self, adv, kwargs):


class TestBase:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
def test_decorators(self):
class MyLoss(LossModule):
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -16033,6 +16037,13 @@ class _AcceptedKeys:


class TestUtils:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
@pytest.mark.parametrize("B", [None, (1, ), (4, ), (2, 2, ), (1, 2, 8, )]) # fmt: skip
@pytest.mark.parametrize("T", [1, 10])
@pytest.mark.parametrize("device", get_default_devices())
@@ -16203,6 +16214,7 @@ def fun(a, b, time_dim=-2):
(SoftUpdate, {"eps": 0.99}),
],
)
+@set_composite_lp_aggregate(False)
def test_updater_warning(updater, kwarg):
with warnings.catch_warnings():
dqn = DQNLoss(torch.nn.Linear(3, 4), delay_value=True, action_space="one_hot")
@@ -16215,6 +16227,13 @@ def test_updater_warning(updater, kwarg):


class TestSingleCall:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
def _mock_value_net(self, has_target, value_key):
model = nn.Linear(3, 1)
module = TensorDictModule(model, in_keys=["obs"], out_keys=[value_key])
@@ -16267,6 +16286,7 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True)
assert (value != value_).all()


+@set_composite_lp_aggregate(False)
def test_instantiate_with_different_keys():
loss_1 = DQNLoss(
value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True
@@ -16281,6 +16301,13 @@ def test_instantiate_with_different_keys():


class TestBuffer:
+@pytest.fixture(scope="class", autouse=True)
+def _composite_log_prob(self):
+setter = set_composite_lp_aggregate(False)
+setter.set()
+yield
+setter.unset()
+
# @pytest.mark.parametrize('dtype', (torch.double, torch.float, torch.half))
# def test_param_cast(self, dtype):
# param = nn.Parameter(torch.zeros(3))
@@ -16390,6 +16417,7 @@ def __init__(self):
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
+@set_composite_lp_aggregate(False)
def test_exploration_compile():
try:
torch._dynamo.reset_code_caches()
@@ -16456,6 +16484,7 @@ def func(t):
assert it == exploration_type()


+@set_composite_lp_aggregate(False)
def test_loss_exploration():
class DummyLoss(LossModule):
def forward(self, td, mode):
