CompressedListStorage¶
- class torchrl.data.replay_buffers.CompressedListStorage(max_size: int, *, compression_fn: Callable | None = None, decompression_fn: Callable | None = None, compression_level: int = 3, device: torch.device = 'cpu', compilable: bool = False)[原始碼]¶
一個壓縮和解壓縮資料的儲存。
此儲存在儲存時壓縮資料,在檢索時解壓縮。它特別適用於儲存可以被顯著壓縮以節省記憶體的原始感官觀察(如影像)。
- 引數:
max_size (int) – 儲存大小,即緩衝區中儲存的最大元素數量。
compression_fn (callable, optional) – 用於壓縮資料的函式。應接受一個張量並返回一個壓縮後的位元組張量。預設為 zstd 壓縮。
decompression_fn (callable, optional) – 用於解壓縮資料的函式。應接受一個壓縮後的位元組張量並返回原始張量。預設為 zstd 解壓縮。
compression_level (int, optional) – 使用預設壓縮函式時,壓縮級別(zstd 為 1-22)。預設為 3。
device (torch.device, optional) – 儲存和傳送取樣張量的裝置。預設為
torch.device("cpu")。compilable (bool, optional) – 儲存是否可編譯。如果為
True,則寫入器不能在多個程序之間共享。預設為False。
示例
>>> import torch >>> from torchrl.data import CompressedListStorage, ReplayBuffer >>> from tensordict import TensorDict >>> >>> # Create a compressed storage for image data >>> storage = CompressedListStorage(max_size=1000, compression_level=3) >>> rb = ReplayBuffer(storage=storage, batch_size=5) >>> >>> # Add some image data >>> images = torch.randn(10, 3, 84, 84) # Atari-like frames >>> data = TensorDict({"obs": images}, batch_size=[10]) >>> rb.extend(data) >>> >>> # Sample and verify data is decompressed correctly >>> sample = rb.sample(3) >>> print(sample["obs"].shape) # torch.Size([3, 3, 84, 84])
- attach(buffer: Any) None¶
此函式將取樣器附加到此儲存。
從該儲存讀取的緩衝區必須透過呼叫此方法作為已附加實體包含進來。這確保了當儲存中的資料發生變化時,元件能夠感知到這些變化,即使該儲存與其他緩衝區(例如,Priority Samplers)共享。
- 引數:
buffer – 讀取此儲存的物件。
- dump(*args, **kwargs)¶
dumps()的別名。
- load(*args, **kwargs)¶
loads()的別名。
- save(*args, **kwargs)¶
dumps()的別名。
- to_bytestream(data_to_bytestream: torch.Tensor | np.array | Any) bytes[原始碼]¶
將資料轉換為位元組流。