Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 17, 2025
2 parents 42ce732 + 043d578 commit f70bc1b
Show file tree
Hide file tree
Showing 7 changed files with 446 additions and 23 deletions.
5 changes: 4 additions & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,17 +1070,20 @@ def _step(

class CountingEnvWithString(CountingEnv):
def __init__(self, *args, **kwargs):
self.max_size = kwargs.pop("max_size", 30)
self.min_size = kwargs.pop("min_size", 4)
super().__init__(*args, **kwargs)
self.observation_spec.set(
"string",
NonTensor(
shape=self.batch_size,
device=self.device,
example_data=self.get_random_string(),
),
)

def get_random_string(self):
size = random.randint(4, 30)
size = random.randint(self.min_size, self.max_size)
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
Expand Down
29 changes: 21 additions & 8 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
assert spec2.zero().shape == spec2.shape

def test_non_tensor(self):
spec = NonTensor((3, 4), device="cpu")
spec = NonTensor((3, 4), device="cpu", example_data="example_data")
assert (
spec.expand(2, 3, 4)
== spec.expand((2, 3, 4))
== NonTensor((2, 3, 4), device="cpu")
== NonTensor((2, 3, 4), device="cpu", example_data="example_data")
)
assert spec.expand(2, 3, 4).example_data == "example_data"

@pytest.mark.parametrize("shape1", [None, (), (5,)])
@pytest.mark.parametrize("shape2", [(), (10,)])
Expand Down Expand Up @@ -1607,9 +1608,10 @@ def test_multionehot(
assert spec is not spec.clone()

def test_non_tensor(self):
spec = NonTensor(shape=(3, 4), device="cpu")
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
assert spec.clone() == spec
assert spec.clone() is not spec
assert spec.clone().example_data == "example_data"

@pytest.mark.parametrize("shape1", [None, (), (5,)])
def test_onehot(
Expand Down Expand Up @@ -1840,9 +1842,10 @@ def test_multionehot(
spec.unbind(-1)

def test_non_tensor(self):
spec = NonTensor(shape=(3, 4), device="cpu")
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
assert spec.unbind(1)[0] == spec[:, 0]
assert spec.unbind(1)[0] is not spec[:, 0]
assert spec.unbind(1)[0].example_data == "example_data"

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(
Expand Down Expand Up @@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
assert spec.to(device).device == device

def test_non_tensor(self, device):
spec = NonTensor(shape=(3, 4), device="cpu")
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
assert spec.to(device).device == device
assert spec.to(device).example_data == "example_data"

@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
def test_onehot(self, shape1, device):
Expand Down Expand Up @@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
assert r.shape == c.shape

def test_stack_non_tensor(self, shape, stack_dim):
spec0 = NonTensor(shape=shape, device="cpu")
spec1 = NonTensor(shape=shape, device="cpu")
spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
new_spec = torch.stack([spec0, spec1], stack_dim)
shape_insert = list(shape)
shape_insert.insert(stack_dim, 2)
assert new_spec.shape == torch.Size(shape_insert)
assert new_spec.device == torch.device("cpu")
assert new_spec.example_data == "example_data"

def test_stack_onehot(self, shape, stack_dim):
n = 5
Expand Down Expand Up @@ -3642,10 +3647,18 @@ def test_expand(self):

class TestNonTensorSpec:
def test_sample(self):
nts = NonTensor(shape=(3, 4))
nts = NonTensor(shape=(3, 4), example_data="example_data")
assert nts.one((2,)).shape == (2, 3, 4)
assert nts.rand((2,)).shape == (2, 3, 4)
assert nts.zero((2,)).shape == (2, 3, 4)
assert nts.one((2,)).data == "example_data"
assert nts.rand((2,)).data == "example_data"
assert nts.zero((2,)).data == "example_data"

def test_example_data_ineq(self):
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
assert nts0 != nts1


@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
Expand Down
219 changes: 218 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
TargetReturn,
TensorDictPrimer,
TimeMaxPool,
Tokenizer,
ToTensorImage,
TrajCounter,
TransformedEnv,
Expand Down Expand Up @@ -2420,7 +2421,223 @@ def test_transform_rb(self, rbclass):
assert ("next", "observation") in td.keys(True)

def test_transform_inverse(self):
raise pytest.skip("No inverse for Hash")
env = CountingEnv()
env = env.append_transform(
Hash(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=["action_hash"],
)
)
assert "action_hash" in env.action_keys
r = env.rollout(3)
env.check_env_specs()
assert "action_hash" in r
assert isinstance(r[0]["action_hash"], torch.Tensor)


class TestTokenizer(TransformBase):
@pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
def test_transform_no_env(self, datatype):
if datatype == "str":
obs = "abcdefg"
elif datatype == "NonTensorStack":
obs = torch.stack(
[
NonTensorData(data="abcde"),
NonTensorData(data="fghij"),
NonTensorData(data="klmno"),
]
)
else:
raise RuntimeError(f"please add a test case for datatype {datatype}")

td = TensorDict(
{
"observation": obs,
}
)

t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
td_tokenized = t(td)
t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
assert td_tokenized.get("observation") is td.get("observation")
assert td_recon["observation"] == td["observation"]

@pytest.mark.parametrize("datatype", ["str"])
def test_single_trans_env_check(self, datatype):
if datatype == "str":
t = Tokenizer(
in_keys=["string"],
out_keys=["tokens"],
max_length=5,
)
base_env = CountingEnvWithString(max_size=4, min_size=4)
env = TransformedEnv(base_env, t)
check_env_specs(env, return_contiguous=False)

@pytest.mark.parametrize("datatype", ["str"])
def test_serial_trans_env_check(self, datatype):
def make_env():
if datatype == "str":
t = Tokenizer(
in_keys=["string"],
out_keys=["tokens"],
max_length=5,
)
base_env = CountingEnvWithString(max_size=4, min_size=4)

return TransformedEnv(base_env, t)

env = SerialEnv(2, make_env)
check_env_specs(env, return_contiguous=False)

@pytest.mark.parametrize("datatype", ["str"])
def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
def make_env():
if datatype == "str":
t = Tokenizer(
in_keys=["string"],
out_keys=["tokens"],
max_length=5,
)
base_env = CountingEnvWithString(max_size=4, min_size=4)
return TransformedEnv(base_env, t)

env = maybe_fork_ParallelEnv(2, make_env)
try:
check_env_specs(env, return_contiguous=False)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.parametrize("datatype", ["str"])
def test_trans_serial_env_check(self, datatype):
if datatype == "str":
t = Tokenizer(
in_keys=["string"],
out_keys=["tokens"],
max_length=5,
)
base_env = partial(CountingEnvWithString, max_size=4, min_size=4)

env = TransformedEnv(SerialEnv(2, base_env), t)
check_env_specs(env, return_contiguous=False)

@pytest.mark.parametrize("datatype", ["str"])
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
if datatype == "str":
t = Tokenizer(
in_keys=["string"],
out_keys=["tokens"],
max_length=5,
)
base_env = partial(CountingEnvWithString, max_size=4, min_size=4)

env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
try:
check_env_specs(env, return_contiguous=False)
finally:
try:
env.close()
except RuntimeError:
pass

@pytest.mark.parametrize("datatype", ["str"])
def test_transform_compose(self, datatype):
if datatype == "str":
obs = "abcdefg"

td = TensorDict(
{
"observation": obs,
}
)
t = Tokenizer(
in_keys=["observation"],
out_keys=["tokens"],
max_length=5,
)
t = Compose(t)
td_tokenized = t(td)

assert td_tokenized["observation"] is td["observation"]
assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensor="pt")

# TODO
def test_transform_model(self):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=hash,
)
model = nn.Sequential(t, nn.Identity())
td = TensorDict(
{("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
)
td_out = model(td)
assert ("next", "hashing") in td_out.keys(True)
assert ("hashing",) in td_out.keys(True)
assert td_out["next", "hashing"] == hash(td["next", "observation"])
assert td_out["hashing"] == hash(td["observation"])

@pytest.mark.skipif(not _has_gym, reason="Gym not found")
def test_transform_env(self):
t = Hash(
in_keys=["observation"],
out_keys=["hashing"],
hash_fn=hash,
)
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
assert env.observation_spec["hashing"]
assert "observation" in env.observation_spec
assert "observation" in env.base_env.observation_spec
check_env_specs(env)

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
def test_transform_rb(self, rbclass):
t = Hash(
in_keys=[("next", "observation"), ("observation",)],
out_keys=[("next", "hashing"), ("hashing",)],
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
)
rb = rbclass(storage=LazyTensorStorage(10))
rb.append_transform(t)
td = TensorDict(
{
"observation": torch.randn(3, 4),
"next": TensorDict(
{"observation": torch.randn(3, 4)},
[],
),
},
[],
).expand(10)
rb.extend(td)
td = rb.sample(2)
assert "hashing" in td.keys()
assert "observation" in td.keys()
assert ("next", "observation") in td.keys(True)

def test_transform_inverse(self):
env = CountingEnv()
env = env.append_transform(
Hash(
in_keys=[],
out_keys=[],
in_keys_inv=["action"],
out_keys_inv=["action_hash"],
)
)
assert "action_hash" in env.action_keys
r = env.rollout(3)
env.check_env_specs()
assert "action_hash" in r
assert isinstance(r[0]["action_hash"], torch.Tensor)


class TestStack(TransformBase):
Expand Down
17 changes: 17 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2452,6 +2452,8 @@ class NonTensor(TensorSpec):
(same will go for :meth:`.zero` and :meth:`.one`).
"""

example_data: Any = None

def __init__(
self,
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
Expand All @@ -2470,6 +2472,11 @@ def __init__(
)
self.example_data = example_data

def __eq__(self, other):
eq = super().__eq__(other)
eq = eq & (self.example_data == getattr(other, "example_data", None))
return eq

def cardinality(self) -> Any:
raise RuntimeError("Cannot enumerate a NonTensorSpec.")

Expand Down Expand Up @@ -2555,6 +2562,16 @@ def expand(self, *shape):
shape=shape, device=self.device, dtype=None, example_data=self.example_data
)

def unsqueeze(self, dim: int) -> NonTensor:
unsq = super().unsqueeze(dim=dim)
unsq.example_data = self.example_data
return unsq

def squeeze(self, dim: int | None = None) -> NonTensor:
sq = super().squeeze(dim=dim)
sq.example_data = self.example_data
return sq

def _reshape(self, shape):
return self.__class__(
shape=shape,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
TargetReturn,
TensorDictPrimer,
TimeMaxPool,
Tokenizer,
ToTensorImage,
TrajCounter,
Transform,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
TargetReturn,
TensorDictPrimer,
TimeMaxPool,
Tokenizer,
ToTensorImage,
TrajCounter,
Transform,
Expand Down
Loading

0 comments on commit f70bc1b

Please sign in to comment.