From 61e05b3d9a967c0cbbda2e355859287ce7221f52 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 13 Jan 2025 16:47:26 +0000 Subject: [PATCH] [BugFix, BE] Document and fix fps passing in recorder and loggers ghstack-source-id: b3996a9a27643eb5da8a78135f6b9fcef3685f17 Pull Request resolved: https://github.com/pytorch/rl/pull/2694 --- torchrl/record/loggers/csv.py | 4 +++- torchrl/record/loggers/mlflow.py | 11 +++++++++-- torchrl/record/loggers/wandb.py | 10 ++++++++-- torchrl/record/recorder.py | 16 ++++++++++++---- 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index a2295bc116a..0052a6149db 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -21,7 +21,7 @@ class CSVExperiment: """A CSV logger experiment class.""" - def __init__(self, log_dir: str, *, video_format="pt", video_fps=30): + def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30): self.scalars = defaultdict(lambda: []) self.videos_counter = defaultdict(lambda: 0) self.text_counter = defaultdict(lambda: 0) @@ -144,6 +144,8 @@ class CSVLogger(Logger): """ + experiment: CSVExperiment + def __init__( self, exp_name: str, diff --git a/torchrl/record/loggers/mlflow.py b/torchrl/record/loggers/mlflow.py index 304b9b3dbe0..a5c8e39e423 100644 --- a/torchrl/record/loggers/mlflow.py +++ b/torchrl/record/loggers/mlflow.py @@ -24,6 +24,10 @@ class MLFlowLogger(Logger): Args: exp_name (str): The name of the experiment. tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory. + + Keyword Args: + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. + """ def __init__( @@ -31,6 +35,8 @@ def __init__( exp_name: str, tracking_uri: str, tags: Optional[Dict[str, Any]] = None, + *, + video_fps: int = 30, **kwargs, ) -> None: import mlflow @@ -43,6 +49,7 @@ def __init__( mlflow.set_tracking_uri(tracking_uri) super().__init__(exp_name=exp_name, log_dir=tracking_uri) self.video_log_counter = 0 + self.video_fps = video_fps def _create_experiment(self) -> "mlflow.ActiveRun": # noqa import mlflow @@ -85,7 +92,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: video (Tensor): The video to be logged, expected to be in (T, C, H, W) format for consistency with other loggers. **kwargs: Other keyword arguments. By construction, log_video - supports 'step' (integer indicating the step index) and 'fps' (default: 6). + supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``). """ import mlflow import torchvision @@ -103,7 +110,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: "The MLFlow logger only supports videos with 3 color channels." ) self.video_log_counter += 1 - fps = kwargs.pop("fps", 6) + fps = kwargs.pop("fps", self.video_fps) step = kwargs.pop("step", None) with TemporaryDirectory() as temp_dir: video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4" diff --git a/torchrl/record/loggers/wandb.py b/torchrl/record/loggers/wandb.py index f0048648f86..c015c2b0214 100644 --- a/torchrl/record/loggers/wandb.py +++ b/torchrl/record/loggers/wandb.py @@ -35,6 +35,9 @@ class WandbLogger(Logger): project (str, optional): The name of the project where you're sending the new run. If the project is not specified, the run is put in an ``"Uncategorized"`` project. + + Keyword Args: + fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``. **kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for more info. @@ -52,6 +55,8 @@ def __init__( save_dir: str = None, id: str = None, project: str = None, + *, + video_fps: int = 32, **kwargs, ) -> None: if not _has_wandb: @@ -68,6 +73,7 @@ def __init__( self.save_dir = save_dir self.id = id self.project = project + self.video_fps = video_fps self._wandb_kwargs = { "name": exp_name, "dir": save_dir, @@ -127,7 +133,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: video (Tensor): The video to be logged. **kwargs: Other keyword arguments. By construction, log_video supports 'step' (integer indicating the step index), 'format' - (default is 'mp4') and 'fps' (default: 6). Other kwargs are + (default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are passed as-is to the :obj:`experiment.log` method. """ import wandb @@ -148,7 +154,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None: "moviepy not found, videos cannot be logged with TensorboardLogger" ) self.video_log_counter += 1 - fps = kwargs.pop("fps", 6) + fps = kwargs.pop("fps", self.video_fps) step = kwargs.pop("step", None) format = kwargs.pop("format", "mp4") if step not in (None, self._prev_video_step, self._prev_video_step + 1): diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index dfab11e5d35..4b800a2d28e 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -49,6 +49,10 @@ class VideoRecorder(ObservationTransform): if not. out_keys (sequence of NestedKey, optional): destination keys. Defaults to ``in_keys`` if not provided. + fps (int, optional): Frames per second of the output video. Defaults to the logger predefined ``fps``, + and overrides it if provided. + **kwargs (Dict[str, Any], optional): additional keyword arguments for + :meth:`~torchrl.record.loggers.Logger.log_video`. Examples: The following example shows how to save a rollout under a video. First a few imports: @@ -81,10 +85,11 @@ class VideoRecorder(ObservationTransform): >>> from torchrl.data.datasets import OpenXExperienceReplay >>> from torchrl.envs import Compose >>> from torchrl.record import VideoRecorder, CSVLogger - >>> # Create a logger that saves videos as mp4 - >>> logger = CSVLogger("./dump", video_format="mp4") + >>> # Create a logger that saves videos as mp4 using 24 frames per sec + >>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24) >>> # We use the VideoRecorder transform to save register the images coming from the batch. - >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")]) + >>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged. + >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12) >>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False) >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200, ... download=True, strict_length=False, @@ -108,6 +113,7 @@ def __init__( center_crop: Optional[int] = None, make_grid: bool | None = None, out_keys: Optional[Sequence[NestedKey]] = None, + fps: int | None = None, **kwargs, ) -> None: if in_keys is None: @@ -115,8 +121,10 @@ def __init__( if out_keys is None: out_keys = copy(in_keys) super().__init__(in_keys=in_keys, out_keys=out_keys) - video_kwargs = {"fps": 6} + video_kwargs = {} video_kwargs.update(kwargs) + if fps is not None: + self.video_kwargs["fps"] = fps self.video_kwargs = video_kwargs self.iter = 0 self.skip = skip