From 585e0422a8a39d973137f69136b7b1834d944742 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 17 Jan 2025 13:29:48 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 80 +++++++++++++++++++++++++-- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 3a4cde38aa2..fed73755502 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -94,6 +94,7 @@ TargetReturn, TensorDictPrimer, TimeMaxPool, + Tokenizer, ToTensorImage, TrajCounter, Transform, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index a25c676e378..7ee142fe811 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -55,6 +55,7 @@ TargetReturn, TensorDictPrimer, TimeMaxPool, + Tokenizer, ToTensorImage, TrajCounter, Transform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 50d2762ad0b..2103e5717e3 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4426,8 +4426,8 @@ class UnaryTransform(Transform): Args: in_keys (sequence of NestedKey): the keys of inputs to the unary operation. out_keys (sequence of NestedKey): the keys of the outputs of the unary operation. - in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call. - out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call. + in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call. + out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call. Keyword Args: fn (Callable): the function to use as the unary operation. 
If it accepts @@ -4569,7 +4569,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: input_spec["full_state_spec"], test_input_spec, ) - print(input_spec) return input_spec def transform_output_spec(self, output_spec: Composite) -> Composite: @@ -4649,8 +4648,8 @@ class Hash(UnaryTransform): Args: in_keys (sequence of NestedKey): the keys of the values to hash. out_keys (sequence of NestedKey): the keys of the resulting hashes. - in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call. - out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call. + in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call. + out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call. Keyword Args: hash_fn (Callable, optional): the hash function to use. If ``seed`` is given, @@ -4801,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None): return torch.frombuffer(hash_bytes, dtype=torch.uint8) +class Tokenizer(UnaryTransform): + r"""Applies a tokenization operation on the specified inputs. + + Args: + in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation. + out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation. + in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call. + out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call. + + Keyword Args: + tokenizer (transformers.PreTrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``, + "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a + pre-trained tokenizer. 
+ use_raw_nontensor (bool, optional): if ``False``, data is extracted from + :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization + function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` + inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``. + additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary. + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey], + in_keys_inv: Sequence[NestedKey] | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, + *, + tokenizer: "transformers.PretrainedTokenizerBase" = None, # noqa: F821 + use_raw_nontensor: bool = False, + additional_tokens: List[str] | None = None, + ): + if tokenizer is None: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + elif isinstance(tokenizer, str): + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + + self.tokenizer = tokenizer + if additional_tokens: + self.tokenizer.add_tokens(additional_tokens) + super().__init__( + in_keys=in_keys, + out_keys=out_keys, + in_keys_inv=in_keys_inv, + out_keys_inv=out_keys_inv, + fn=self.call_tokenizer_fn, + use_raw_nontensor=use_raw_nontensor, + ) + + @property + def device(self): + if "_device" in self.__dict__: + return self._device + parent = self.parent + if parent is None: + return None + device = parent.device + self._device = device + return device + + def call_tokenizer_fn(self, value: str | List[str]): + device = self.device + out = self.tokenizer.encode(value, return_tensors="pt") + if device is not None and out.device != device: + out = out.to(device) + return out + + class Stack(Transform): """Stacks tensors and tensordicts.