快捷方式

tensorclass

class tensordict.tensorclass(cls: Optional[T] = None, /, *, autocast: bool = False, frozen: bool = False, nocast: bool = False, shadow: bool = False, tensor_only: bool = False)

一個用於建立 tensorclass 類的裝飾器。

tensorclass 類是專門化的 dataclasses.dataclass() 例項,它們可以開箱即用地執行一些預定義的張量操作,例如索引、項賦值、重塑、轉換為裝置或儲存等等。

關鍵字引數:
  • autocast (bool, optional) – 如果為 True,則在設定引數時會強制執行型別。此引數與 autocast 互斥(兩者不能同時為 True)。預設為 False

  • frozen (bool, optional) – 如果為 True,則 tensorclass 的內容無法修改。此引數是為了與 dataclass 相容而提供的,透過類建構函式中的 lock 引數可以獲得類似的行為。預設為 False

  • nocast (bool, optional) – 如果為 True,則 Tensor 相容的型別(如 intnp.ndarray 等)不會被轉換為張量型別。此引數與 autocast 互斥(兩者不能同時為 True)。預設為 False

  • shadow (bool, optional) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。

  • tensor_only (bool, optional) – 如果為 True,則預期 tensorclass 中的所有項都將是張量例項(張量相容,因為非張量資料會被儘可能轉換為張量)。這可以帶來顯著的速度提升,但會犧牲與非張量資料的靈活互動。預設為 False

tensorclass 可以帶引數或不帶引數使用

示例

>>> @tensorclass
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=False)
... class X:
...     y: int
>>> X(torch.ones(())).y
tensor(1.)
>>> @tensorclass(autocast=True)
... class X:
...     y: int
>>> X(torch.ones(())).y
1
>>> @tensorclass(nocast=True)
... class X:
...     y: Any
>>> X(1).y
1
>>> @tensorclass(nocast=False)
... class X:
...     y: Any
>>> X(1).y
tensor(1)

示例

>>> from tensordict import tensorclass
>>> import torch
>>> from typing import Optional
>>>
>>> @tensorclass
... class MyData:
...     X: torch.Tensor
...     y: torch.Tensor
...     z: str
...     def expand_and_mask(self):
...         X = self.X.unsqueeze(-1).expand_as(self.y)
...         X = X[self.y]
...         return X
...
>>> data = MyData(
...     X=torch.ones(3, 4, 1),
...     y=torch.zeros(3, 4, 2, 2, dtype=torch.bool),
...     z="test"
...     batch_size=[3, 4])
>>> print(data)
MyData(
    X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
    y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool),
    z="test"
    batch_size=[3, 4],
    device=None,
    is_shared=False)
>>> print(data.expand_and_mask())
tensor([])
也可以將 tensorclasses 例項巢狀在彼此內部

示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 儘管資料儲存為 TensorDict,但型別提示有助於我們 >>> # 將資料正確地轉換為正確的型別 >>> assert isinstance(nesting_data.nested, type(data))

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源