from_dataclass¶
- class tensordict.from_dataclass(obj: Any, *, dest_cls: Optional[Type] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None, batch_size: Optional[Size] = None, frozen: bool = False, autocast: bool = False, nocast: bool = False, inplace: bool = False, shadow: bool = False, tensor_only: bool = False, device: Optional[device] = None)¶
將 dataclass 例項或型別分別轉換為 tensorclass 例項或型別。
此函式接受一個 dataclass 例項或 dataclass 型別,並將其轉換為 tensor 相容的類,還可以選擇性地應用各種配置,例如自動批處理、不可變性和型別轉換。
- 引數:
obj (Any) – 要轉換的 dataclass 例項或型別。如果提供了型別,則返回新類。
- 關鍵字引數:
dest_cls (tensorclass, optional) – 用於對映資料的 tensorclass 型別。如果未提供,則建立新類。如果
obj是一個型別,則此引數無效。auto_batch_size (bool, optional) – 如果為
True,將自動確定並應用批次大小到結果物件。預設為False。batch_dims (int, optional) – 如果 auto_batch_size 為
True,則定義輸出 tensordict 應具有的維度數。預設為None(每個級別完全批次大小)。batch_size (torch.Size, optional) – TensorDict 的批次大小。預設為
None。frozen (bool, optional) – 如果為
True,則結果類或例項將是不可變的。預設為False。autocast (bool, optional) – 如果為
True,則為結果類或例項啟用自動型別轉換。預設為False。nocast (bool, optional) – 如果為
True,則停用結果類或例項的任何型別轉換。預設為False。tensor_only (bool, optional) – 如果為
True,則預期 tensorclass 中的所有項都將是張量例項(張量相容,因為非張量資料會被儘可能轉換為張量)。這可以帶來顯著的速度提升,但會犧牲與非張量資料的靈活互動。預設為False。inplace (bool, optional) – 如果為
True,則將就地修改提供的 dataclass 型別。預設為False。如果提供了例項,則此引數無效。device (torch.device, optional) – 將建立 TensorDict 的裝置。預設為
None。shadow (bool, optional) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。
- 返回:
一個派生自提供的 dataclass 的 tensor 相容類或例項。
- 丟擲:
TypeError – 如果提供的輸入不是 dataclass 例項或型別。
示例
>>> from dataclasses import dataclass >>> import torch >>> from tensordict.tensorclass import from_dataclass >>> >>> @dataclass >>> class X: ... a: int ... b: torch.Tensor ... >>> x = X(0, 0) >>> x2 = from_dataclass(x) >>> print(x2) X( a=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False) >>> X2 = from_dataclass(X, autocast=True) >>> print(X2(a=0, b=0)) X( a=NonTensorData(data=0, batch_size=torch.Size([]), device=None), b=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), batch_size=torch.Size([]), device=None, is_shared=False)
注意
如果提供了 dataclass 型別,則會返回一個帶有指定配置的新類。如果提供了 dataclass 例項,則會返回 tensor 相容類的新例項。auto_batch_size、frozen、autocast 和 nocast 選項允許靈活配置結果類或例項。
警告
雖然
from_dataclass()預設返回TensorDict例項,但此方法將返回一個 tensorclass 例項或型別。