快捷方式

LazyTensorStorage

class torchrl.data.replay_buffers.LazyTensorStorage(max_size: int, *, device: device = 'cpu', ndim: int = 1, compilable: bool = False, consolidated: bool = False)[原始碼]

用於張量和張量字典的預分配張量儲存。

引數:

max_size (int) – 儲存大小,即緩衝區中儲存的最大元素數量。

關鍵字引數:
  • device (torch.device, optional) – 儲存和傳送取樣張量的裝置。預設為 torch.device("cpu")。如果傳入“auto”,則裝置將從傳入的第一個資料批次中自動收集。此選項預設不啟用,以避免意外將資料放置在 GPU 上導致 OOM 問題。

  • ndim (int, optional) – 計算儲存大小時要考慮的維度數。例如,形狀為 [3, 4] 的儲存,如果 ndim=1,則容量為 3;如果 ndim=2,則容量為 12。預設為 1

  • compilable (bool, optional) – 儲存是否可編譯。如果為 True,則寫入器不能在多個程序之間共享。預設為 False

  • consolidated (bool, optional) – 如果為 True,則儲存將在首次擴充套件後進行合併。預設為 False

示例

>>> data = TensorDict({
...     "some data": torch.randn(10, 11),
...     ("some", "nested", "data"): torch.randn(10, 11, 12),
... }, batch_size=[10, 11])
>>> storage = LazyTensorStorage(100)
>>> storage.set(range(10), data)
>>> len(storage)  # only the first dimension is considered as indexable
10
>>> storage.get(0)
TensorDict(
    fields={
        some data: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        some: TensorDict(
            fields={
                nested: TensorDict(
                    fields={
                        data: Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([11]),
                    device=cpu,
                    is_shared=False)},
            batch_size=torch.Size([11]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)
>>> storage.set(0, storage.get(0).zero_()) # zeros the data along index ``0``

此類也支援 tensorclass 資料。

示例

>>> from tensordict import tensorclass
>>> @tensorclass
... class MyClass:
...     foo: torch.Tensor
...     bar: torch.Tensor
>>> data = MyClass(foo=torch.randn(10, 11), bar=torch.randn(10, 11, 12), batch_size=[10, 11])
>>> storage = LazyTensorStorage(10)
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
    bar=Tensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
    foo=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
    batch_size=torch.Size([11]),
    device=cpu,
    is_shared=False)
attach(buffer: Any) None

此函式將取樣器附加到此儲存。

從該儲存讀取的緩衝區必須透過呼叫此方法作為已附加實體包含進來。這確保了當儲存中的資料發生變化時,元件能夠感知到這些變化,即使該儲存與其他緩衝區(例如,Priority Samplers)共享。

引數:

buffer – 讀取此儲存的物件。

dump(*args, **kwargs)

dumps() 的別名。

load(*args, **kwargs)

loads() 的別名。

save(*args, **kwargs)

dumps() 的別名。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源