tensorclass¶
- class tensordict.tensorclass(cls: Optional[T] = None, /, *, autocast: bool = False, frozen: bool = False, nocast: bool = False, shadow: bool = False, tensor_only: bool = False)¶
一個用於建立
tensorclass類的裝飾器。tensorclass類是專門化的dataclasses.dataclass()例項,它們可以開箱即用地執行一些預定義的張量操作,例如索引、項賦值、重塑、轉換為裝置或儲存等等。- 關鍵字引數:
autocast (bool, optional) – 如果為
True,則在設定引數時會強制執行型別。此引數與autocast互斥(兩者不能同時為 True)。預設為False。frozen (bool, optional) – 如果為
True,則 tensorclass 的內容無法修改。此引數是為了與 dataclass 相容而提供的,透過類建構函式中的 lock 引數可以獲得類似的行為。預設為False。nocast (bool, optional) – 如果為
True,則 Tensor 相容的型別(如int、np.ndarray等)不會被轉換為張量型別。此引數與autocast互斥(兩者不能同時為 True)。預設為False。shadow (bool, optional) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。
tensor_only (bool, optional) – 如果為
True,則預期 tensorclass 中的所有項都將是張量例項(張量相容,因為非張量資料會被儘可能轉換為張量)。這可以帶來顯著的速度提升,但會犧牲與非張量資料的靈活互動。預設為False。
tensorclass 可以帶引數或不帶引數使用
示例
>>> @tensorclass ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=False) ... class X: ... y: int >>> X(torch.ones(())).y tensor(1.) >>> @tensorclass(autocast=True) ... class X: ... y: int >>> X(torch.ones(())).y 1 >>> @tensorclass(nocast=True) ... class X: ... y: Any >>> X(1).y 1 >>> @tensorclass(nocast=False) ... class X: ... y: Any >>> X(1).y tensor(1)
示例
>>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass ... class MyData: ... X: torch.Tensor ... y: torch.Tensor ... z: str ... def expand_and_mask(self): ... X = self.X.unsqueeze(-1).expand_as(self.y) ... X = X[self.y] ... return X ... >>> data = MyData( ... X=torch.ones(3, 4, 1), ... y=torch.zeros(3, 4, 2, 2, dtype=torch.bool), ... z="test" ... batch_size=[3, 4]) >>> print(data) MyData( X=Tensor(torch.Size([3, 4, 1]), dtype=torch.float32), y=Tensor(torch.Size([3, 4, 2, 2]), dtype=torch.bool), z="test" batch_size=[3, 4], device=None, is_shared=False) >>> print(data.expand_and_mask()) tensor([])
- 也可以將 tensorclasses 例項巢狀在彼此內部
示例: >>> from tensordict import tensorclass >>> import torch >>> from typing import Optional >>> >>> @tensorclass … class NestingMyData: … nested: MyData … >>> nesting_data = NestingMyData(nested=data, batch_size=[3, 4]) >>> # 儘管資料儲存為 TensorDict,但型別提示有助於我們 >>> # 將資料正確地轉換為正確的型別 >>> assert isinstance(nesting_data.nested, type(data))