Commit 585e042

Update

[ghstack-poisoned]

vmoens committed Jan 17, 2025
1 parent 08c1fd1 commit 585e042
Showing 3 changed files with 77 additions and 5 deletions.
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
@@ -94,6 +94,7 @@
TargetReturn,
TensorDictPrimer,
TimeMaxPool,
+    Tokenizer,
ToTensorImage,
TrajCounter,
Transform,
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
@@ -55,6 +55,7 @@
TargetReturn,
TensorDictPrimer,
TimeMaxPool,
+    Tokenizer,
ToTensorImage,
TrajCounter,
Transform,
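With the export added in both modules, the new transform can be imported from either namespace:

    from torchrl.envs import Tokenizer
    # or, equivalently:
    from torchrl.envs.transforms import Tokenizer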
80 changes: 75 additions & 5 deletions torchrl/envs/transforms/transforms.py
@@ -4426,8 +4426,8 @@ class UnaryTransform(Transform):
Args:
in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
-        out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call.
Keyword Args:
fn (Callable): the function to use as the unary operation. If it accepts
@@ -4569,7 +4569,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
input_spec["full_state_spec"],
test_input_spec,
)
-        print(input_spec)
return input_spec

def transform_output_spec(self, output_spec: Composite) -> Composite:
@@ -4649,8 +4648,8 @@ class Hash(UnaryTransform):
Args:
in_keys (sequence of NestedKey): the keys of the values to hash.
out_keys (sequence of NestedKey): the keys of the resulting hashes.
-        in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
-        out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
Keyword Args:
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
@@ -4801,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
return torch.frombuffer(hash_bytes, dtype=torch.uint8)


+class Tokenizer(UnaryTransform):
+    r"""Applies a tokenization operation on the specified inputs.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.
+
+    Keyword Args:
+        tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+            pre-trained tokenizer.
+        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
+            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
+            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
+            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
+        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
+        tokenizer: "transformers.PreTrainedTokenizerBase" = None,  # noqa: F821
+        use_raw_nontensor: bool = False,
+        additional_tokens: List[str] | None = None,
+    ):
+        if tokenizer is None:
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        elif isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        self.tokenizer = tokenizer
+        if additional_tokens:
+            self.tokenizer.add_tokens(additional_tokens)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+            fn=self.call_tokenizer_fn,
+            use_raw_nontensor=use_raw_nontensor,
+        )
+
+    @property
+    def device(self):
+        if "_device" in self.__dict__:
+            return self._device
+        parent = self.parent
+        if parent is None:
+            return None
+        device = parent.device
+        self._device = device
+        return device
+
+    def call_tokenizer_fn(self, value: str | List[str]):
+        device = self.device
+        out = self.tokenizer.encode(value, return_tensors="pt")
+        if device is not None and out.device != device:
+            out = out.to(device)
+        return out


class Stack(Transform):
"""Stacks tensors and tensordicts.
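For illustration, a minimal usage sketch of the new transform (not part of the commit; the key names and the extra token here are hypothetical):

    from torchrl.envs import Tokenizer

    # Tokenize the string stored at "text" into token ids written at "tokens",
    # using a named pre-trained tokenizer and one extra vocabulary entry.
    t = Tokenizer(
        in_keys=["text"],
        out_keys=["tokens"],
        tokenizer="bert-base-uncased",
        additional_tokens=["<MY_TOKEN>"],  # hypothetical extra token
    )

    # The underlying tokenization function can also be called directly:
    ids = t.call_tokenizer_fn("hello world")  # LongTensor of shape [1, seq_len]

In an environment pipeline it would typically be attached to a TransformedEnv (e.g. via env.append_transform(t)) so that string observations are tokenized on the fly.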
