• 文件 >
  • 使用 TensorDict 預分配記憶體
快捷方式

使用 TensorDict 進行記憶體預分配

作者: Tom Begley

在本教程中,您將學習如何利用 TensorDict 中的記憶體預分配功能。

假設我們有一個函式,它返回一個 TensorDict

import torch
from tensordict.tensordict import TensorDict


def make_tensordict():
    return TensorDict({"a": torch.rand(3), "b": torch.rand(3, 4)}, [3])

也許我們想多次呼叫此函式,並將結果填充到一個 TensorDict 中。

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])

for i in range(N):
    tensordict[i] = make_tensordict()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

由於我們已經指定了 tensordictbatch_size,在迴圈的第一次迭代中,我們用空張量填充 tensordict,其第一個維度的大小為 N,其餘維度由 make_tensordict 的返回值確定。在上面的示例中,我們為鍵 "a" 預分配了一個大小為 torch.Size([10, 3]) 的零陣列,併為鍵 "b" 預分配了一個大小為 torch.Size([10, 3, 4]) 的陣列。後續的迴圈迭代是就地寫入的。因此,如果並非所有值都已填充,它們將獲得預設值零。

讓我們透過逐步分析上述迴圈來演示正在發生的情況。我們首先初始化一個空的 TensorDict

N = 10
tensordict = TensorDict({}, batch_size=[N, 3])
print(tensordict)
TensorDict(
    fields={
    },
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

第一次迭代後,tensordict 已經為 "a""b" 預先填充了張量。這些張量包含零,除了我們為其分配了隨機值的第一行。

random_tensordict = make_tensordict()
tensordict[0] = random_tensordict

assert (tensordict[1:] == 0).all()
assert (tensordict[0] == random_tensordict).all()

print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10, 3]),
    device=None,
    is_shared=False)

在後續的迭代中,我們對預分配的張量進行就地更新。

a = tensordict["a"]
random_tensordict = make_tensordict()
tensordict[1] = random_tensordict

# the same tensor is stored under "a", but the values have been updated
assert tensordict["a"] is a
assert (tensordict[:2] != 0).all()

指令碼總執行時間: (0 分 0.003 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源