CatFrames¶
- class torchrl.envs.transforms.CatFrames(N: int, dim: int, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, padding='same', padding_value=0, as_inverse=False, reset_key: NestedKey | None = None, done_key: NestedKey | None = None)[原始碼]¶
將連續的觀察幀連線成一個單一的張量。
此轉換有助於在觀察到的特徵中建立運動或速度感。它也可以與需要訪問過去觀察的模型(如 transformers 等)一起使用。它最初在“Playing Atari with Deep Reinforcement Learning”中提出(https://arxiv.org/pdf/1312.5602.pdf)。
當在轉換後的環境中用作轉換器時,
CatFrames是一個有狀態的類,可以透過呼叫reset()方法將其重置為其初始狀態。此方法接受帶有"_reset"條目的 tensordicts,該條目指示要重置的緩衝區。- 引數:
N (int) – 要連線的觀察次數。
dim (int) – 連線觀察的維度。應為負數,以確保其與不同 batch_size 的環境相容。
in_keys (NestedKey 序列, 可選) – 指向需要連線的幀的鍵。預設為 [“pixels”]。
out_keys (NestedKey 序列, 可選) – 指向輸出寫入位置的鍵。預設為 in_keys 的值。
padding (str, 可選) – 填充方法。可以是
"same"或"constant"。預設為"same",即第一個值用於填充。padding_value (
float, 可選) – 如果padding="constant",則用於填充的值。預設為 0。as_inverse (bool, 可選) – 如果為
True,則轉換作為逆轉換應用。預設為False。reset_key (NestedKey, 可選) – 要用作部分重置指示器的重置鍵。必須是唯一的。如果未提供,則預設為父環境的唯一重置鍵(如果只有一個),否則引發異常。
done_key (NestedKey, 可選) – 要用作部分完成指示器的完成鍵。必須是唯一的。如果未提供,則預設為
"done"。
示例
>>> from torchrl.envs.libs.gym import GymEnv >>> env = TransformedEnv(GymEnv('Pendulum-v1'), ... Compose( ... UnsqueezeTransform(-1, in_keys=["observation"]), ... CatFrames(N=4, dim=-1, in_keys=["observation"]), ... ) ... ) >>> print(env.rollout(3))
CatFrames 轉換器也可以離線使用,以不同比例重現線上幀連線的效果(或為了限制記憶體消耗)。下面的示例以及
torchrl.data.ReplayBuffer的用法給出了完整的圖景。示例
>>> from torchrl.envs.utils import RandomPolicy >>> from torchrl.envs import UnsqueezeTransform, CatFrames >>> from torchrl.collectors import SyncDataCollector >>> # Create a transformed environment with CatFrames: notice the usage of UnsqueezeTransform to create an extra dimension >>> env = TransformedEnv( ... GymEnv("CartPole-v1", from_pixels=True), ... Compose( ... ToTensorImage(in_keys=["pixels"], out_keys=["pixels_trsf"]), ... Resize(in_keys=["pixels_trsf"], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf"]), ... UnsqueezeTransform(-4, in_keys=["pixels_trsf"]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf"]), ... ) ... ) >>> # we design a collector >>> collector = SyncDataCollector( ... env, ... RandomPolicy(env.action_spec), ... frames_per_batch=10, ... total_frames=1000, ... ) >>> for data in collector: ... print(data) ... break >>> # now let's create a transform for the replay buffer. We don't need to unsqueeze the data here. >>> # however, we need to point to both the pixel entry at the root and at the next levels: >>> t = Compose( ... ToTensorImage(in_keys=["pixels", ("next", "pixels")], out_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... Resize(in_keys=["pixels_trsf", ("next", "pixels_trsf")], w=64, h=64), ... GrayScale(in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... CatFrames(dim=-4, N=4, in_keys=["pixels_trsf", ("next", "pixels_trsf")]), ... ) >>> from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), transform=t, batch_size=16) >>> data_exclude = data.exclude("pixels_trsf", ("next", "pixels_trsf")) >>> rb.add(data_exclude) >>> s = rb.sample(1) # the buffer has only one element >>> # let's check that our sample is the same as the batch collected during inference >>> assert (data.exclude("collector")==s.squeeze(0).exclude("index", "collector")).all()
注意
CatFrames目前僅支援根目錄下的"done"訊號。巢狀的done,例如在 MARL 設定中找到的,目前不支援。如果需要此功能,請在 TorchRL 儲存庫上提交一個 issue。注意
在回放緩衝區中儲存幀堆疊會顯著增加記憶體消耗(增加 N 倍)。為了緩解這個問題,您可以直接將軌跡儲存在回放緩衝區中,並在取樣時應用
CatFrames。這種方法涉及取樣儲存的軌跡的切片,然後應用幀堆疊轉換。為了方便起見,CatFrames提供了一個make_rb_transform_and_sampler()方法,該方法建立:一個適合在回放緩衝區中使用的轉換器的修改版本
一個對應的
SliceSampler以便與緩衝區一起使用
- forward(tensordict: TensorDictBase) TensorDictBase[原始碼]¶
讀取輸入 tensordict,並對選定的鍵應用轉換。
預設情況下,此方法
直接呼叫
_apply_transform()。不呼叫
_step()或_call()。
此方法不會在任何時候在 env.step 中呼叫。但是,它會在
sample()中呼叫。注意
forward也可以使用dispatch將引數名稱轉換為鍵,並使用常規關鍵字引數。示例
>>> class TransformThatMeasuresBytes(Transform): ... '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.''' ... def __init__(self): ... super().__init__(in_keys=[], out_keys=["bytes"]) ... ... def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ... bytes_in_td = tensordict.bytes() ... tensordict["bytes"] = bytes ... return tensordict >>> t = TransformThatMeasuresBytes() >>> env = env.append_transform(t) # works within envs >>> t(TensorDict(a=0)) # Works offline too.
- make_rb_transform_and_sampler(batch_size: int, **sampler_kwargs) tuple[Transform, torchrl.data.replay_buffers.SliceSampler][原始碼]¶
建立一個轉換器和取樣器,用於在儲存幀堆疊資料時與回放緩衝區一起使用。
此方法透過避免在緩衝區中儲存整個幀堆疊來幫助減少儲存資料中的冗餘。相反,它建立了一個在取樣過程中即時堆疊幀的轉換器,以及一個確保正確維護序列長度的取樣器。
- 引數:
batch_size (int) – 取樣器使用的批次大小。
**sampler_kwargs – 傳遞給
SliceSampler建構函式的其他關鍵字引數。
- 返回:
transform (Transform): 一個在取樣過程中即時堆疊幀的轉換器。
sampler (SliceSampler): 一個確保正確維護序列長度的取樣器。
- 返回型別:
一個包含的元組
示例
>>> env = TransformedEnv(...) >>> catframes = CatFrames(N=4, ...) >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
注意
處理影像時,建議在前面的
ToTensorImage轉換器中使用不同的in_keys和out_keys。這確保了儲存在緩衝區中的張量與它們的處理後的對應物是分開的,而我們不希望儲存這些處理後的對應物。對於非影像資料,請考慮在CatFrames之前插入一個RenameTransform來建立一個將被儲存在緩衝區中的資料副本。注意
將轉換器新增到回放緩衝區時,應注意還要傳遞 CatFrames 前面的轉換器,例如
ToTensorImage或UnsqueezeTransform,以便CatFrames轉換器看到的資料格式與資料收集期間的格式相同。注意
有關更完整的示例,請參閱 torchrl 的 github 儲存庫 examples 資料夾:https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
- transform_observation_spec(observation_spec: TensorSpec) TensorSpec[原始碼]¶
轉換觀察規範,使結果規範與轉換對映匹配。
- 引數:
observation_spec (TensorSpec) – 轉換前的規範
- 返回:
轉換後的預期規範