• 文件 >
  • 使用 TensorDict 簡化 PyTorch 記憶體管理
快捷方式

使用 TensorDict 簡化 PyTorch 記憶體管理

作者: Tom Begley

在本教程中,您將學習如何透過將 TensorDict 的內容傳送到裝置或利用記憶體對映來控制其在記憶體中的儲存位置。

裝置

建立 TensorDict 時,可以使用 device 關鍵字引數指定裝置。如果設定了 device,則 TensorDict 的所有條目都將放在該裝置上。如果未設定 device,則 TensorDict 中的條目沒有必須在同一裝置上的要求。

在此示例中,我們使用 device="cuda:0" 例項化了一個 TensorDict。當我們列印其內容時,可以看到它們已被移至裝置。

>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict({"a": torch.rand(10)}, [10], device="cuda:0")
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

如果 TensorDict 的裝置不是 None,則新新增的條目也會被移至該裝置。

>>> tensordict["b"] = torch.rand(10, 10)
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

您可以使用 device 屬性來檢查 TensorDict 的當前裝置。

>>> print(tensordict.device)
cuda:0

可以使用 TensorDict.cuda()TensorDict.device(device) 方法將 TensorDict 的內容傳送到裝置,其中 device 是目標裝置。這就像傳送 PyTorch 張量一樣。

>>> tensordict.to(torch.device("cpu"))
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
>>> tensordict.cuda()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=cuda:0,
    is_shared=True)

TensorDict.device 方法需要一個有效的裝置作為引數。如果您想從 TensorDict 中移除裝置以允許不同裝置的條目,則應使用 TensorDict.clear_device 方法。

>>> tensordict.clear_device()
>>> print(tensordict)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([10]), device=cuda:0, dtype=torch.float32, is_shared=True),
        b: Tensor(shape=torch.Size([10, 10]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

記憶體對映張量

tensordict 提供了一個類 MemoryMappedTensor,它允許我們將張量的內容儲存在磁碟上,同時仍然支援快速索引和分批載入內容。有關此功能的實際示例,請參閱 ImageNet 教程

要將 TensorDict 轉換為一系列記憶體對映張量,請使用 TensorDict.memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
tensordict.memmap_()

print(tensordict)
TensorDict(
    fields={
        a: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: MemoryMappedTensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)

或者,可以使用 TensorDict.memmap_like 方法。這將建立一個具有相同結構且值為 MemoryMappedTensor 的新 TensorDict,但它不會將原始張量的內容複製到記憶體對映張量中。這允許您建立記憶體對映的 TensorDict,然後緩慢地填充它,因此通常應優先於 memmap_

tensordict = TensorDict({"a": torch.rand(10), "b": {"c": torch.rand(10)}}, [10])
mm_tensordict = tensordict.memmap_like()

print(mm_tensordict["a"].contiguous())
MemoryMappedTensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

預設情況下,TensorDict 的內容將儲存到磁碟上的臨時位置。但是,如果您想控制儲存位置,可以使用關鍵字引數 prefix="/path/to/root"

TensorDict 的內容將以模仿 TensorDict 本身結構的目錄結構儲存。張量的內容儲存在 NumPy memmap 中,元資料儲存在相關的 PyTorch 儲存檔案中。例如,上面的 TensorDict 儲存如下:

├── a.memmap
├── a.meta.pt
├── b
│ ├── c.memmap
│ ├── c.meta.pt
│ └── meta.pt
└── meta.pt

指令碼總執行時間: (0 分鐘 0.005 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源