快捷方式

tensorclass

@tensorclass 裝飾器幫助您構建繼承自 TensorDict 行為的自定義類,同時能夠將可能的條目限制為預定義集合或為您的類實現自定義方法。

TensorDict 一樣,@tensorclass 支援巢狀、索引、重塑、項賦值。它還支援張量操作,如 clone, squeeze, torch.cat, split 等。@tensorclass 允許非張量條目,但是所有張量操作都嚴格限制在張量屬性上。

需要為非張量資料實現自定義方法。需要注意的是,@tensorclass 不強制嚴格的型別匹配

>>> from __future__ import annotations
>>> from tensordict.prototype import tensorclass
>>> import torch
>>> from torch import nn
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     floatdata: torch.Tensor
...     intdata: torch.Tensor
...     non_tensordata: str
...     nested: Optional[MyData] = None
...
...     def check_nested(self):
...         assert self.nested is not None
>>>
>>> data = MyData(
...   floatdata=torch.randn(3, 4, 5),
...   intdata=torch.randint(10, (3, 4, 1)),
...   non_tensordata="test",
...   batch_size=[3, 4]
... )
>>> print("data:", data)
data: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=None,
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)
>>> data.nested = MyData(
...     floatdata = torch.randn(3, 4, 5),
...     intdata=torch.randint(10, (3, 4, 1)),
...     non_tensordata="nested_test",
...     batch_size=[3, 4]
... )
>>> print("nested:", data)
nested: MyData(
  floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
  intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
  non_tensordata='test',
  nested=MyData(
      floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([3, 4]),
      device=None,
      is_shared=False),
  batch_size=torch.Size([3, 4]),
  device=None,
  is_shared=False)

正如 TensorDict 的情況一樣,從 v0.4 開始,如果省略批次大小,則認為其為空。

如果提供了非空批次大小,@tensorclass 支援索引。內部會索引張量物件,但是非張量資料保持不變

>>> print("indexed:", data[:2])
indexed: MyData(
   floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test',
   nested=MyData(
      floatdata=Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
      intdata=Tensor(shape=torch.Size([2, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
      non_tensordata='nested_test',
      nested=None,
      batch_size=torch.Size([2, 4]),
      device=None,
      is_shared=False),
   batch_size=torch.Size([2, 4]),
   device=None,
   is_shared=False)

@tensorclass 還支援設定和重置屬性,即使是巢狀物件。

>>> data.non_tensordata = "test_changed"
>>> print("data.non_tensordata: ", repr(data.non_tensordata))
data.non_tensordata: 'test_changed'

>>> data.floatdata = torch.ones(3, 4, 5)
>>> print("data.floatdata:", data.floatdata)
data.floatdata: tensor([[[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]],

      [[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]]])

>>> # Changing nested tensor data
>>> data.nested.non_tensordata = "nested_test_changed"
>>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
data.nested.non_tensordata: 'nested_test_changed'

@tensorclass 支援對其內容進行形狀和裝置的多個 torch 操作,例如 stack, cat, reshapeto(device)。要獲取支援操作的完整列表,請參閱 tensordict 文件。

這是一個例子:

>>> data2 = data.clone()
>>> cat_tc = torch.cat([data, data2], 0)
>>> print("Concatenated data:", catted_tc)
Concatenated data: MyData(
   floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
   non_tensordata='test_changed',
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
       intdata=Tensor(shape=torch.Size([6, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
       non_tensordata='nested_test_changed',
       nested=None,
       batch_size=torch.Size([6, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([6, 4]),
   device=None,
   is_shared=False)

序列化

儲存 tensorclass 例項可以透過 memmap 方法實現。儲存策略如下:張量資料將使用記憶體對映張量儲存,而可以使用 json 格式序列化的非張量資料將以此方式儲存。其他資料型別將使用 save() 儲存,該方法依賴於 pickle

反序列化 tensorclass 可以透過 load_memmap() 完成。建立的例項將具有與儲存的例項相同的型別,前提是 tensorclass 在工作環境中可用

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))

邊緣情況

@tensorclass 支援相等和不等運算子,即使是巢狀物件。請注意,非張量/元資料未經過驗證。這將返回一個具有布林值(用於張量屬性)和 None(用於非張量屬性)的張量類物件

這是一個例子:

>>> print(data == data2)
MyData(
   floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
   intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
   non_tensordata=None,
   nested=MyData(
       floatdata=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
       intdata=Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
       non_tensordata=None,
       nested=None,
       batch_size=torch.Size([3, 4]),
       device=None,
       is_shared=False),
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

@tensorclass 支援設定一個項。但是,在設定項時,會進行非張量/元資料的標識檢查而不是相等性檢查,以避免效能問題。使用者需要確保項的非張量資料與物件匹配,以避免差異。

這是一個例子:

在設定具有不同 non_tensor 資料的項時,會丟擲 UserWarning

>>> data2.non_tensordata = "test_new"
>>> data[0] = data2[0]
UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

儘管 @tensorclass 支援 cat()stack() 等 torch 函式,但非張量/元資料不會被驗證。torch 操作會在張量資料上執行,並在返回輸出時,會考慮第一個 tensor class 物件的非張量/元資料。使用者需要確保所有 tensor class 物件列表具有相同的非張量資料,以避免差異

這是一個例子:

>>> data2.non_tensordata = "test_new"
>>> stack_tc = torch.cat([data, data2], dim=0)
>>> print(stack_tc)
MyData(
    floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
    intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
    non_tensordata='test',
    nested=MyData(
        floatdata=Tensor(shape=torch.Size([2, 3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        intdata=Tensor(shape=torch.Size([2, 3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        non_tensordata='nested_test',
        nested=None,
        batch_size=torch.Size([2, 3, 4]),
        device=None,
        is_shared=False),
    batch_size=torch.Size([2, 3, 4]),
    device=None,
    is_shared=False)

@tensorclass 還支援預分配,您可以將屬性初始化為 None,然後稍後設定它們。請注意,在初始化時,內部的 None 屬性將儲存為非張量/元資料,而在重置時,根據屬性值的型別,它將被儲存為張量資料或非張量/元資料

這是一個例子:

>>> @tensorclass
... class MyClass:
...   X: Any
...   y: Any

>>> data = MyClass(X=None, y=None, batch_size = [3,4])
>>> data.X = torch.ones(3, 4, 5)
>>> data.y = "testing"
>>> print(data)
MyClass(
   X=Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
   y='testing',
   batch_size=torch.Size([3, 4]),
   device=None,
   is_shared=False)

tensorclass([cls, autocast, frozen, nocast, ...])

一個用於建立 tensorclass 類的裝飾器。

TensorClass(*args, **kwargs)

TensorClass 是 @tensorclass 裝飾器的基於繼承的版本。

NonTensorData(data[, _metadata, ...])

MetaData(data[, _metadata, _is_non_tensor, ...])

NonTensorStack(*args, **kwargs)

LazyStackedTensorDict 的一個薄包裝器,用於輕鬆識別非張量資料的堆疊。

from_dataclass(obj, *[, dest_cls, ...])

將 dataclass 例項或型別分別轉換為 tensorclass 例項或型別。

自動型別轉換

警告

自動型別轉換是一項實驗性功能,未來可能會發生變化。與 python<=3.9 的相容性有限。

@tensorclass 作為一項實驗性功能部分支援自動型別轉換。__setattr__, update, update_from_dict 等方法將嘗試將型別註解的條目轉換為所需的 TensorDict / tensorclass 例項(除非發生如下所述的情況)。例如,以下程式碼將把 td 字典轉換為 TensorDict,並將 tc 條目轉換為 MyClass 例項

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> assert isinstance(obj.tc, TensorDict)
>>> assert isinstance(obj.td, MyClass)

注意

包含 typing.Optionaltyping.Union 的型別註解條目將與自動型別轉換不相容,但 tensorclass 中的其他條目將相容

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     tc_autocast: MyClass = None
...     tc_not_autocast: Optional[MyClass] = None
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     tc_autocast={"tensor": torch.randn(())},
...     tc_not_autocast={"tensor": torch.randn(())},
... )
>>> assert isinstance(obj.tc_autocast, MyClass)
>>> # because the type is Optional or Union, auto-casting is disabled for
>>> # that variable.
>>> assert not isinstance(obj.tc_not_autocast, MyClass)

如果類中的至少一個條目使用 type0 | type1 語義進行註解,則整個類的自動型別轉換功能將被停用。因為 tensorclass 支援非張量葉子,在這種情況下設定字典將導致將其設定為普通字典而不是 tensor collection 子類(TensorDicttensorclass

>>> @tensorclass
... class MyClass:
...     tensor: torch.Tensor
...     td: TensorDict
...     tc: MyClass | None
...
>>> obj = MyClass(
...     tensor=torch.randn(()),
...     td={"a": torch.randn(())},
...     tc={"tensor": torch.randn(()), "td": None, "tc": None})
>>> assert isinstance(obj.tensor, torch.Tensor)
>>> # tc and td have not been cast
>>> assert isinstance(obj.tc, dict)
>>> assert isinstance(obj.td, dict)

注意

自動型別轉換未對葉子(張量)啟用。原因是此功能與包含 type0 | type1 型別提示語義的型別註解不相容,後者很普遍。允許自動型別轉換將導致非常相似的程式碼,如果型別註解僅有細微差別,行為就會有很大差異。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源