Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 13, 2025
1 parent 2f5fc5b commit a01b2ad
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
4 changes: 3 additions & 1 deletion torchrl/record/loggers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
class CSVExperiment:
"""A CSV logger experiment class."""

def __init__(self, log_dir: str, *, video_format="pt", video_fps=30):
def __init__(self, log_dir: str, *, video_format="pt", video_fps: int = 30):
self.scalars = defaultdict(lambda: [])
self.videos_counter = defaultdict(lambda: 0)
self.text_counter = defaultdict(lambda: 0)
Expand Down Expand Up @@ -144,6 +144,8 @@ class CSVLogger(Logger):
"""

experiment: CSVExperiment

def __init__(
self,
exp_name: str,
Expand Down
11 changes: 9 additions & 2 deletions torchrl/record/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,19 @@ class MLFlowLogger(Logger):
Args:
exp_name (str): The name of the experiment.
tracking_uri (str): A tracking URI to a datastore that supports MLFlow or a local directory.
Keyword Args:
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
"""

def __init__(
self,
exp_name: str,
tracking_uri: str,
tags: Optional[Dict[str, Any]] = None,
*,
video_fps: int = 30,
**kwargs,
) -> None:
import mlflow
Expand All @@ -43,6 +49,7 @@ def __init__(
mlflow.set_tracking_uri(tracking_uri)
super().__init__(exp_name=exp_name, log_dir=tracking_uri)
self.video_log_counter = 0
self.video_fps = video_fps

def _create_experiment(self) -> "mlflow.ActiveRun": # noqa
import mlflow
Expand Down Expand Up @@ -85,7 +92,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
video (Tensor): The video to be logged, expected to be in (T, C, H, W) format
for consistency with other loggers.
**kwargs: Other keyword arguments. By construction, log_video
supports 'step' (integer indicating the step index) and 'fps' (default: 6).
supports 'step' (integer indicating the step index) and 'fps' (defaults to ``self.video_fps``).
"""
import mlflow
import torchvision
Expand All @@ -103,7 +110,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"The MLFlow logger only supports videos with 3 color channels."
)
self.video_log_counter += 1
fps = kwargs.pop("fps", 6)
fps = kwargs.pop("fps", self.video_fps)
step = kwargs.pop("step", None)
with TemporaryDirectory() as temp_dir:
video_name = f"{name}_step_{step:04}.mp4" if step else f"{name}.mp4"
Expand Down
10 changes: 8 additions & 2 deletions torchrl/record/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class WandbLogger(Logger):
project (str, optional): The name of the project where you're sending
the new run. If the project is not specified, the run is put in
an ``"Uncategorized"`` project.
Keyword Args:
fps (int, optional): Number of frames per second when recording videos. Defaults to ``30``.
**kwargs: Extra keyword arguments for ``wandb.init``. See relevant page for
more info.
Expand All @@ -52,6 +55,8 @@ def __init__(
save_dir: str = None,
id: str = None,
project: str = None,
*,
video_fps: int = 32,
**kwargs,
) -> None:
if not _has_wandb:
Expand All @@ -68,6 +73,7 @@ def __init__(
self.save_dir = save_dir
self.id = id
self.project = project
self.video_fps = video_fps
self._wandb_kwargs = {
"name": exp_name,
"dir": save_dir,
Expand Down Expand Up @@ -127,7 +133,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
video (Tensor): The video to be logged.
**kwargs: Other keyword arguments. By construction, log_video
supports 'step' (integer indicating the step index), 'format'
(default is 'mp4') and 'fps' (default: 6). Other kwargs are
(default is 'mp4') and 'fps' (defaults to ``self.video_fps``). Other kwargs are
passed as-is to the :obj:`experiment.log` method.
"""
import wandb
Expand All @@ -148,7 +154,7 @@ def log_video(self, name: str, video: Tensor, **kwargs) -> None:
"moviepy not found, videos cannot be logged with TensorboardLogger"
)
self.video_log_counter += 1
fps = kwargs.pop("fps", 6)
fps = kwargs.pop("fps", self.video_fps)
step = kwargs.pop("step", None)
format = kwargs.pop("format", "mp4")
if step not in (None, self._prev_video_step, self._prev_video_step + 1):
Expand Down
16 changes: 12 additions & 4 deletions torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class VideoRecorder(ObservationTransform):
if not.
out_keys (sequence of NestedKey, optional): destination keys. Defaults
to ``in_keys`` if not provided.
fps (int, optional): Frames per second of the output video. Defaults to the logger predefined ``fps``,
and overrides it if provided.
**kwargs (Dict[str, Any], optional): additional keyword arguments for
:meth:`~torchrl.record.loggers.Logger.log_video`.
Examples:
The following example shows how to save a rollout under a video. First a few imports:
Expand Down Expand Up @@ -81,10 +85,11 @@ class VideoRecorder(ObservationTransform):
>>> from torchrl.data.datasets import OpenXExperienceReplay
>>> from torchrl.envs import Compose
>>> from torchrl.record import VideoRecorder, CSVLogger
>>> # Create a logger that saves videos as mp4
>>> logger = CSVLogger("./dump", video_format="mp4")
>>> # Create a logger that saves videos as mp4 using 24 frames per sec
>>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24)
>>> # We use the VideoRecorder transform to save register the images coming from the batch.
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")])
>>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged.
>>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12)
>>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False)
>>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200,
... download=True, strict_length=False,
Expand All @@ -108,15 +113,18 @@ def __init__(
center_crop: Optional[int] = None,
make_grid: bool | None = None,
out_keys: Optional[Sequence[NestedKey]] = None,
fps: int | None = None,
**kwargs,
) -> None:
if in_keys is None:
in_keys = ["pixels"]
if out_keys is None:
out_keys = copy(in_keys)
super().__init__(in_keys=in_keys, out_keys=out_keys)
video_kwargs = {"fps": 6}
video_kwargs = {}
video_kwargs.update(kwargs)
if fps is not None:
self.video_kwargs["fps"] = fps
self.video_kwargs = video_kwargs
self.iter = 0
self.skip = skip
Expand Down

0 comments on commit a01b2ad

Please sign in to comment.