From ff1ff7e9cf13026f544dd7565edac1acc6c458ec Mon Sep 17 00:00:00 2001 From: Faury Louis Date: Mon, 13 Jan 2025 17:22:16 +0100 Subject: [PATCH 1/2] [Feature] Linearise reward transform (#2681) Co-authored-by: Louis Faury Co-authored-by: Vincent Moens --- docs/source/reference/envs.rst | 1 + test/test_transforms.py | 331 +++++++++++++++++++++++++- torchrl/data/__init__.py | 1 + torchrl/envs/__init__.py | 1 + torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 121 +++++++++- 6 files changed, 453 insertions(+), 3 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 065d6a2e3d4..ede1421ffc9 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -829,6 +829,7 @@ to be able to create this other composition: GrayScale InitTracker KLRewardTransform + LineariseReward NoopResetEnv ObservationNorm ObservationTransform diff --git a/test/test_transforms.py b/test/test_transforms.py index 44ebce72c5c..7a01acdaeef 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -84,6 +84,7 @@ from torchrl._utils import _replace_last, prod from torchrl.data import ( Bounded, + BoundedContinuous, Categorical, Composite, LazyTensorStorage, @@ -92,6 +93,7 @@ TensorSpec, TensorStorage, Unbounded, + UnboundedContinuous, ) from torchrl.envs import ( ActionMask, @@ -117,6 +119,7 @@ GrayScale, gSDENoise, InitTracker, + LineariseRewards, MultiStepTransform, NoopResetEnv, ObservationNorm, @@ -412,7 +415,7 @@ def test_transform_rb(self, rbclass): assert ((sample["reward"] == 0) | (sample["reward"] == 1)).all() def test_transform_inverse(self): - raise pytest.skip("No inverse for BinerizedReward") + raise pytest.skip("No inverse for BinarizedReward") class TestClipTransform(TransformBase): @@ -12403,6 +12406,332 @@ def test_transform_inverse(self): pytest.skip("Tested elsewhere") +class TestLineariseRewards(TransformBase): + def test_weight_shape_error(self): + with pytest.raises( + ValueError, match="Expected weights to be a unidimensional tensor" + ): + LineariseRewards(in_keys=("reward",), weights=torch.ones(size=(2, 4))) + + def test_weight_sign_error(self): + with pytest.raises(ValueError, match="Expected all weights to be >0"): + LineariseRewards(in_keys=("reward",), weights=-torch.ones(size=(2,))) + + def test_discrete_spec_error(self): + with pytest.raises( + NotImplementedError, + match="Aggregation of rewards that take discrete values is not supported.", + ): + transform = LineariseRewards(in_keys=("reward",)) + reward_spec = Categorical(n=2) + transform.transform_reward_spec(reward_spec) + + @pytest.mark.parametrize( + "reward_spec", + [ + UnboundedContinuous(shape=3), + BoundedContinuous(0, 1, shape=2), + ], + ) + def test_single_trans_env_check(self, reward_spec: TensorSpec): + env = TransformedEnv( + ContinuousActionVecMockEnv(reward_spec=reward_spec), + LineariseRewards(in_keys=["reward"]), # will use default weights + ) + check_env_specs(env) + + @pytest.mark.parametrize( + "reward_spec", + [ + UnboundedContinuous(shape=3), + BoundedContinuous(0, 1, shape=2), + ], + ) + def test_serial_trans_env_check(self, reward_spec: TensorSpec): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(reward_spec=reward_spec), + LineariseRewards(in_keys=["reward"]), # will use default weights + ) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + @pytest.mark.parametrize( + "reward_spec", + [ + UnboundedContinuous(shape=3), + BoundedContinuous(0, 1, shape=2), + ], + ) + def 
test_parallel_trans_env_check( + self, maybe_fork_ParallelEnv, reward_spec: TensorSpec + ): + def make_env(): + return TransformedEnv( + ContinuousActionVecMockEnv(reward_spec=reward_spec), + LineariseRewards(in_keys=["reward"]), # will use default weights + ) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize( + "reward_spec", + [ + UnboundedContinuous(shape=3), + BoundedContinuous(0, 1, shape=2), + ], + ) + def test_trans_serial_env_check(self, reward_spec: TensorSpec): + def make_env(): + return ContinuousActionVecMockEnv(reward_spec=reward_spec) + + env = TransformedEnv( + SerialEnv(2, make_env), LineariseRewards(in_keys=["reward"]) + ) + check_env_specs(env) + + @pytest.mark.parametrize( + "reward_spec", + [ + UnboundedContinuous(shape=3), + BoundedContinuous(0, 1, shape=2), + ], + ) + def test_trans_parallel_env_check( + self, maybe_fork_ParallelEnv, reward_spec: TensorSpec + ): + def make_env(): + return ContinuousActionVecMockEnv(reward_spec=reward_spec) + + env = TransformedEnv( + maybe_fork_ParallelEnv(2, make_env), + LineariseRewards(in_keys=["reward"]), + ) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("reward_key", [("reward",), ("agents", "reward")]) + @pytest.mark.parametrize( + "num_rewards, weights", + [ + (1, None), + (3, None), + (2, [1.0, 2.0]), + ], + ) + def test_transform_no_env(self, reward_key, num_rewards, weights): + out_keys = reward_key[:-1] + ("scalar_reward",) + t = LineariseRewards(in_keys=[reward_key], out_keys=[out_keys], weights=weights) + td = TensorDict({reward_key: torch.randn(num_rewards)}, []) + t._call(td) + + weights = torch.ones(num_rewards) if weights is None else torch.tensor(weights) + expected = sum( + w * r + for w, r in zip( + weights, + td[reward_key], + ) + ) + torch.testing.assert_close(td[out_keys], expected) + + @pytest.mark.parametrize("reward_key", [("reward",), ("agents", "reward")]) + @pytest.mark.parametrize( + "num_rewards, weights", + [ + (1, None), + (3, None), + (2, [1.0, 2.0]), + ], + ) + def test_transform_compose(self, reward_key, num_rewards, weights): + out_keys = reward_key[:-1] + ("scalar_reward",) + t = Compose( + LineariseRewards(in_keys=[reward_key], out_keys=[out_keys], weights=weights) + ) + td = TensorDict({reward_key: torch.randn(num_rewards)}, []) + t._call(td) + + weights = torch.ones(num_rewards) if weights is None else torch.tensor(weights) + expected = sum( + w * r + for w, r in zip( + weights, + td[reward_key], + ) + ) + torch.testing.assert_close(td[out_keys], expected) + + class _DummyMultiObjectiveEnv(EnvBase): + """A dummy multi-objective environment.""" + + def __init__(self, num_rewards: int) -> None: + super().__init__() + self._num_rewards = num_rewards + + self.observation_spec = Composite( + observation=UnboundedContinuous((*self.batch_size, 3)) + ) + self.action_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) + self.done_spec = Categorical(2, (*self.batch_size, 1), dtype=torch.bool) + self.full_done_spec["truncated"] = self.full_done_spec["terminated"].clone() + self.reward_spec = UnboundedContinuous(*self.batch_size, num_rewards) + + def _reset(self, tensordict: TensorDict) -> TensorDict: + return self.observation_spec.sample() + + def _step(self, tensordict: TensorDict) -> TensorDict: + done, terminated = False, False + reward = torch.randn((self._num_rewards,)) + + return TensorDict( + { + 
("observation"): self.observation_spec["observation"].sample(), + ("done"): done, + ("terminated"): terminated, + ("reward"): reward, + } + ) + + def _set_seed(self) -> None: + pass + + @pytest.mark.parametrize( + "num_rewards, weights", + [ + (1, None), + (3, None), + (2, [1.0, 2.0]), + ], + ) + def test_transform_env(self, num_rewards, weights): + weights = weights if weights is not None else [1.0 for _ in range(num_rewards)] + + transform = LineariseRewards( + in_keys=("reward",), out_keys=("scalar_reward",), weights=weights + ) + env = TransformedEnv(self._DummyMultiObjectiveEnv(num_rewards), transform) + rollout = env.rollout(10) + scalar_reward = rollout.get(("next", "scalar_reward")) + assert scalar_reward.shape[-1] == 1 + + expected = sum( + w * r + for w, r in zip(weights, rollout.get(("next", "reward")).split(1, dim=-1)) + ) + torch.testing.assert_close(scalar_reward, expected) + + @pytest.mark.parametrize( + "num_rewards, weights", + [ + (1, None), + (3, None), + (2, [1.0, 2.0]), + ], + ) + def test_transform_model(self, num_rewards, weights): + weights = weights if weights is not None else [1.0 for _ in range(num_rewards)] + transform = LineariseRewards( + in_keys=("reward",), out_keys=("scalar_reward",), weights=weights + ) + + model = nn.Sequential(transform, nn.Identity()) + td = TensorDict({"reward": torch.randn(num_rewards)}, []) + model(td) + + expected = sum(w * r for w, r in zip(weights, td["reward"])) + torch.testing.assert_close(td["scalar_reward"], expected) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + num_rewards = 3 + weights = None + transform = LineariseRewards( + in_keys=("reward",), out_keys=("scalar_reward",), weights=weights + ) + + rb = rbclass(storage=LazyTensorStorage(10)) + td = TensorDict({"reward": torch.randn(num_rewards)}, []).expand(10) + rb.append_transform(transform) + rb.extend(td) + + td = rb.sample(2) + torch.testing.assert_close(td["scalar_reward"], td["reward"].sum(-1)) + + def test_transform_inverse(self): + raise pytest.skip("No inverse for LineariseReward") + + @pytest.mark.parametrize( + "weights, reward_spec, expected_spec", + [ + (None, UnboundedContinuous(shape=3), UnboundedContinuous(shape=1)), + ( + None, + BoundedContinuous(0, 1, shape=3), + BoundedContinuous(0, 3, shape=1), + ), + ( + None, + BoundedContinuous(low=[-1.0, -2.0], high=[1.0, 2.0]), + BoundedContinuous(low=-3.0, high=3.0, shape=1), + ), + ( + [1.0, 0.0], + BoundedContinuous( + low=[-1.0, -2.0], + high=[1.0, 2.0], + shape=2, + ), + BoundedContinuous(low=-1.0, high=1.0, shape=1), + ), + ], + ) + def test_reward_spec( + self, + weights, + reward_spec: TensorSpec, + expected_spec: TensorSpec, + ) -> None: + transform = LineariseRewards(in_keys=("reward",), weights=weights) + assert transform.transform_reward_spec(reward_spec) == expected_spec + + def test_composite_reward_spec(self) -> None: + weights = None + reward_spec = Composite( + agent_0=Composite( + reward=BoundedContinuous(low=[0, 0, 0], high=[1, 1, 1], shape=3) + ), + agent_1=Composite( + reward=BoundedContinuous( + low=[-1, -1, -1], + high=[1, 1, 1], + shape=3, + ) + ), + ) + expected_reward_spec = Composite( + agent_0=Composite(reward=BoundedContinuous(low=0, high=3, shape=1)), + agent_1=Composite(reward=BoundedContinuous(low=-3, high=3, shape=1)), + ) + transform = LineariseRewards( + in_keys=[("agent_0", "reward"), ("agent_1", "reward")], weights=weights + ) + assert transform.transform_reward_spec(reward_spec) == expected_reward_spec 
+ + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 639ed820e86..3ed65d59d16 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -72,6 +72,7 @@ Binary, BinaryDiscreteTensorSpec, Bounded, + BoundedContinuous, BoundedTensorSpec, Categorical, Composite, diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index b863ad0801c..bcb50899549 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -69,6 +69,7 @@ gSDENoise, InitTracker, KLRewardTransform, + LineariseRewards, MultiStepTransform, NoopResetEnv, ObservationNorm, diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 77f6ecc03bf..9e261eee8f2 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -32,6 +32,7 @@ GrayScale, gSDENoise, InitTracker, + LineariseRewards, NoopResetEnv, ObservationNorm, ObservationTransform, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8e074fa8679..ea4b32d7300 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -63,6 +63,7 @@ from torchrl.data.tensor_specs import ( Binary, Bounded, + BoundedContinuous, Categorical, Composite, ContinuousBox, @@ -71,6 +72,7 @@ OneHot, TensorSpec, Unbounded, + UnboundedContinuous, ) from torchrl.envs.common import _do_nothing, _EnvPostInit, EnvBase, make_tensordict from torchrl.envs.transforms import functional as F @@ -4234,7 +4236,7 @@ class CatTensors(Transform): del_keys (bool, optional): if ``True``, the input values will be deleted after concatenation. Default is ``True``. unsqueeze_if_oor (bool, optional): if ``True``, CatTensor will check that - the dimension indicated exist for the tensors to concatenate. If not, + the indicated dimension exists for the tensors to concatenate. If not, the tensors will be unsqueezed along that dimension. Default is ``False``. sort (bool, optional): if ``True``, the keys will be sorted in the @@ -7709,7 +7711,7 @@ class BurnInTransform(Transform): .. note:: This transform expects as inputs TensorDicts with its last dimension being the - time dimension. It also assumes that all provided modules can process + time dimension. It also assumes that all provided modules can process sequential data. Examples: @@ -9287,3 +9289,118 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: high=torch.iinfo(torch.int64).max, ) return super().transform_observation_spec(observation_spec) + + +class LineariseRewards(Transform): + """Transforms a multi-objective reward signal to a single-objective one via a weighted sum. + + Args: + in_keys (List[NestedKey]): The keys under which the multi-objective rewards are found. + out_keys (List[NestedKey], optional): The keys under which single-objective rewards should be written. Defaults to :attr:`in_keys`. + weights (List[float], Tensor, optional): Dictates how to weight each reward when summing them. Defaults to `[1.0, 1.0, ...]`. + + .. warning:: + If a sequence of `in_keys` of length strictly greater than one is passed (e.g. one group for each agent in a + multi-agent set-up), the same weights will be applied for each entry. If you need to aggregate rewards + differently for each group, use several :class:`~torchrl.envs.LineariseRewards` in a row. 
+ + Example: + >>> import mo_gymnasium as mo_gym + >>> from torchrl.envs import MOGymWrapper + >>> mo_env = MOGymWrapper(mo_gym.make("deep-sea-treasure-v0")) + >>> mo_env.reward_spec + BoundedContinuous( + shape=torch.Size([2]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)), + ...) + >>> so_env = TransformedEnv(mo_env, LineariseRewards(in_keys=("reward",))) + >>> so_env.reward_spec + BoundedContinuous( + shape=torch.Size([1]), + space=ContinuousBox( + low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), + high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), + ...) + >>> td = so_env.rollout(5) + >>> td["next", "reward"].shape + torch.Size([5, 1]) + """ + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_keys: Sequence[NestedKey] | None = None, + *, + weights: Sequence[float] | Tensor | None = None, + ) -> None: + out_keys = in_keys if out_keys is None else out_keys + super().__init__(in_keys=in_keys, out_keys=out_keys) + + if weights is not None: + weights = weights if isinstance(weights, Tensor) else torch.tensor(weights) + + # This transform should only receive vectorial weights (all batch dimensions will be aggregated similarly). + if weights.ndim >= 2: + raise ValueError( + f"Expected weights to be a unidimensional tensor. Got {weights.ndim} dimension." + ) + + # Avoids switching from reward to costs. + if (weights < 0).any(): + raise ValueError(f"Expected all weights to be >0. Got {weights}.") + + self.register_buffer("weights", weights) + else: + self.weights = None + + @_apply_to_composite + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + if not reward_spec.domain == "continuous": + raise NotImplementedError( + "Aggregation of rewards that take discrete values is not supported." + ) + + *batch_size, num_rewards = reward_spec.shape + weights = ( + torch.ones(num_rewards, device=reward_spec.device, dtype=reward_spec.dtype) + if self.weights is None + else self.weights + ) + + num_weights = torch.numel(weights) + if num_weights != num_rewards: + raise ValueError( + "The number of rewards and weights should match. " + f"Got: {num_rewards} and {num_weights}" + ) + + if isinstance(reward_spec, UnboundedContinuous): + reward_spec.shape = torch.Size([*batch_size, 1]) + return reward_spec + + # The lines below are correct only if all weights are positive. + low = (weights * reward_spec.space.low).sum(dim=-1, keepdim=True) + high = (weights * reward_spec.space.high).sum(dim=-1, keepdim=True) + + return BoundedContinuous( + low=low, + high=high, + device=reward_spec.device, + dtype=reward_spec.dtype, + ) + + def _apply_transform(self, reward: Tensor) -> TensorDictBase: + if self.weights is None: + return reward.sum(dim=-1) + + *batch_size, num_rewards = reward.shape + num_weights = torch.numel(self.weights) + if num_weights != num_rewards: + raise ValueError( + "The number of rewards and weights should match. " + f"Got: {num_rewards} and {num_weights}." 
+ ) + + return (self.weights * reward).sum(dim=-1) From 61e05b3d9a967c0cbbda2e355859287ce7221f52 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 13 Jan 2025 16:47:26 +0000 Subject: [PATCH 2/2] [BugFix, BE] Document and fix fps passing in recorder and loggers ghstack-source-id: b3996a9a27643eb5da8a78135f6b9fcef3685f17 Pull Request resolved: https://github.com/pytorch/rl/pull/2694 --- torchrl/record/loggers/csv.py | 4 +++- torchrl/record/loggers/mlflow.py | 11 +++++++++-- torchrl/record/loggers/wandb.py | 10 ++++++++-- torchrl/record/recorder.py | 16 ++++++++++++---- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index a2295bc116a..0052a6149db 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -21,7 +21,7 @@ class CSVExperiment: """A CSV logger experiment class.""" - def __init__(self, log_dir: str, *, video_format="pt", video_fps=30): + def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30): self.scalars = defaultdict(lambda: []) self.videos_counter = defaultdict(lambda: 0) self.text_counter = defaultdict(lambda: 0) @@ -144,6 +144,8 @@ class CSVLogger(Logger): """ + experiment: CSVExperiment + def __init__( self, exp_name: str, diff --git a/torchrl/record/loggers/mlflow.py b/torchrl/record/loggers/mlflow.py index 304b9b3dbe0..a5c8e39e423 100644 --- a/torchrl/record/loggers/mlflow.py +++ b/torchrl/record/loggers/mlflow.py @@ -24,6 +24,10 @@ class MLFlowLogger(Logger): Args: exp_name (str): The name of the experiment. tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory. + + Keyword Args: + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. + """ def __init__( @@ -31,6 +35,8 @@ def __init__( exp_name: str, tracking_uri: str, tags: Optional[Dict[str, Any]] = None, + *, + video_fps: int = 30, **kwargs, ) -> None: import mlflow @@ -43,6 +49,7 @@ def __init__( mlflow.set_tracking_uri(tracking_uri) super().__init__(exp_name=exp_name, log_dir=tracking_uri) self.video_log_counter = 0 + self.video_fps = video_fps def _create_experiment(self) -> "mlflow.ActiveRun": # noqa import mlflow @@ -85,7 +92,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: video (Tensor): The video to be logged, expected to be in (T, C, H, W) format for consistency with other loggers. **kwargs: Other keyword arguments. By construction, log_video - supports 'step' (integer indicating the step index) and 'fps' (default: 6). + supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``). """ import mlflow import torchvision @@ -103,7 +110,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: "The MLFlow logger only supports videos with 3 color channels." ) self.video_log_counter += 1 - fps = kwargs.pop("fps", 6) + fps = kwargs.pop("fps", self.video_fps) step = kwargs.pop("step", None) with TemporaryDirectory() as temp_dir: video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4" diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index f0048648f86..c015c2b0214 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -35,6 +35,9 @@ class WandbLogger(Logger): project (str, optional): The name of the project where you're sending the new run. If the project is not specified, the run is put in an ``"Uncategorized"`` project. 
+
+    Keyword Args:
+        video_fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
         **kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
             more info.
 
@@ -52,6 +55,8 @@ def __init__(
         save_dir: str = None,
         id: str = None,
         project: str = None,
+        *,
+        video_fps: int = 30,
         **kwargs,
     ) -> None:
         if not _has_wandb:
@@ -68,6 +73,7 @@ def __init__(
         self.save_dir = save_dir
         self.id = id
         self.project = project
+        self.video_fps = video_fps
         self._wandb_kwargs = {
             "name": exp_name,
             "dir": save_dir,
@@ -127,7 +133,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
             video (Tensor): The video to be logged.
             **kwargs: Other keyword arguments. By construction, log_video
                 supports 'step' (integer indicating the step index), 'format'
-                (default is 'mp4') and 'fps' (default: 6). Other kwargs are
+                (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
                 passed as-is to the :obj:`experiment.log` method.
         """
         import wandb
@@ -148,7 +154,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
                 "moviepy not found, videos cannot be logged with TensorboardLogger"
             )
         self.video_log_counter += 1
-        fps = kwargs.pop("fps", 6)
+        fps = kwargs.pop("fps", self.video_fps)
         step = kwargs.pop("step", None)
         format = kwargs.pop("format", "mp4")
         if step not in (None, self._prev_video_step, self._prev_video_step + 1):
diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py
index dfab11e5d35..4b800a2d28e 100644
--- a/torchrl/record/recorder.py
+++ b/torchrl/record/recorder.py
@@ -49,6 +49,10 @@ class VideoRecorder(ObservationTransform):
             if not.
         out_keys (sequence of NestedKey, optional): destination keys. Defaults
             to ``in_keys`` if not provided.
+        fps (int, optional): Frames per second of the output video. If provided, it overrides the
+            ``fps`` predefined by the logger; otherwise, the logger's value is used.
+        **kwargs (Dict[str, Any], optional): additional keyword arguments for
+            :meth:`~torchrl.record.loggers.Logger.log_video`.
 
     Examples:
         The following example shows how to save a rollout under a video. First a few imports:
@@ -81,10 +85,11 @@ class VideoRecorder(ObservationTransform):
         >>> from torchrl.data.datasets import OpenXExperienceReplay
         >>> from torchrl.envs import Compose
         >>> from torchrl.record import VideoRecorder, CSVLogger
-        >>> # Create a logger that saves videos as mp4
-        >>> logger = CSVLogger("./dump", video_format="mp4")
+        >>> # Create a logger that saves videos as mp4 using 24 frames per sec
+        >>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
         >>> # We use the VideoRecorder transform to save register the images coming from the batch.
-        >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
+        >>> # Setting fps to 12 overrides the value set in the logger; omitting it keeps the logger's value.
+        >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
         >>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
         >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
         ...     download=True, strict_length=False,
@@ -108,6 +113,7 @@ def __init__(
         center_crop: Optional[int] = None,
         make_grid: bool | None = None,
         out_keys: Optional[Sequence[NestedKey]] = None,
+        fps: int | None = None,
         **kwargs,
     ) -> None:
         if in_keys is None:
@@ -115,8 +121,10 @@ def __init__(
         if out_keys is None:
             out_keys = copy(in_keys)
         super().__init__(in_keys=in_keys, out_keys=out_keys)
-        video_kwargs = {"fps": 6}
+        video_kwargs = {}
         video_kwargs.update(kwargs)
+        if fps is not None:
+            video_kwargs["fps"] = fps
         self.video_kwargs = video_kwargs
         self.iter = 0
         self.skip = skip