VideoRecorder¶
- torchrl.record.VideoRecorder(logger: Logger, tag: str, in_keys: Sequence[NestedKey] | None = None, skip: int | None = None, center_crop: int | None = None, make_grid: bool | None = None, out_keys: Sequence[NestedKey] | None = None, fps: int | None = None, **kwargs) None[原始碼]¶
影片錄製器轉換。
當需要時,將從環境中記錄一系列觀測,並將它們寫入 Logger 物件。
- 引數:
logger (Logger) – 一個 Logger 例項,影片應被寫入其中。要將影片儲存為 memmap 張量或 mp4 檔案,請使用
CSVLogger類。tag (str) – logger 中的影片標籤。
in_keys (Sequence of NestedKey, optional) – 用於生成影片的讀取鍵。預設為
"pixels"。skip (int) – 輸出影片中的幀間隔。如果變換有父環境,則預設為
2,如果沒有,則預設為1。center_crop (int, optional) – 方形中心裁剪的值。
make_grid (bool, optional) – 如果為
True,則建立一個網格,假設提供形狀為 [B x W x H x 3] 的張量,其中 B 是批次大小。如果變換有父環境,則預設為True,如果不是,則預設為False。out_keys (sequence of NestedKey, optional) – 目標鍵。如果未提供,則預設為
in_keys。fps (int, optional) – 輸出影片的每秒幀數。預設為 logger 預定義的
fps,如果提供則覆蓋它。**kwargs (Dict[str, Any], optional) –
log_video()的附加關鍵字引數。
示例
以下示例展示瞭如何將一個 rollout 儲存為影片。首先匯入一些內容。
>>> from torchrl.record import VideoRecorder >>> from torchrl.record.loggers.csv import CSVLogger >>> from torchrl.envs import TransformedEnv, DMControlEnv
影片格式在 logger 中選擇。Wandb 和 tensorboard 會自行處理。CSV 支援各種影片格式。
>>> logger = CSVLogger(exp_name="cheetah", log_dir="cheetah_videos", video_format="mp4")
某些環境(例如 Atari 遊戲)會原生返回影像,而另一些則需要使用者主動請求。請參閱
GymEnv或DMControlEnv,瞭解在這些情況下如何渲染影像。>>> base_env = DMControlEnv("cheetah", "run", from_pixels=True) >>> env = TransformedEnv(base_env, VideoRecorder(logger=logger, tag="run_video")) >>> env.rollout(100)
所有變換都有一個 dump 函式,除了
VideoRecorder和Compose(它會將 dumps 分發給所有成員)之外,大多數都只是一個 no-op。>>> env.transform.dump()
該變換也可以在資料集內使用,以儲存收集到的影片。與環境中的情況不同,影像將以批次形式出現。
skip引數將允許僅在特定間隔儲存影像。>>> from torchrl.data.datasets import OpenXExperienceReplay >>> from torchrl.envs import Compose >>> from torchrl.record import VideoRecorder, CSVLogger >>> # Create a logger that saves videos as mp4 using 24 frames per sec >>> logger = CSVLogger("./dump", video_format="mp4", video_fps=24) >>> # We use the VideoRecorder transform to save register the images coming from the batch. >>> # Setting the fps to 12 overrides the one set in the logger, not doing so keeps it unchanged. >>> t = VideoRecorder(logger=logger, tag="pixels", in_keys=[("next", "observation", "image")], fps=12) >>> # Each batch of data will have 10 consecutive videos of 200 frames each (maximum, since strict_length=False) >>> dataset = OpenXExperienceReplay("cmu_stretch", batch_size=2000, slice_len=200, ... download=True, strict_length=False, ... transform=t) >>> # Get a batch of data and visualize it >>> for data in dataset: ... t.dump() ... break
我們的影片可在
./cheetah_videos/cheetah/videos/run_video_0.mp4下找到!