快捷方式

TensorClass

class tensordict.TensorClass(*args, **kwargs)

TensorClass 是 `@tensorclass` 裝飾器的基於繼承的版本。

TensorClass 允許你編寫比使用 `@tensorclass` 裝飾器構建的 dataclasses 具有更好的型別檢查和更具 Pythonic 的程式碼。

示例

>>> from typing import Any
>>> import torch
>>> from tensordict import TensorClass
>>> class Foo(TensorClass):
...     tensor: torch.Tensor
...     non_tensor: Any
...     nested: Any = None
>>> foo = Foo(tensor=torch.randn(3), non_tensor="a string!", nested=None, batch_size=[3])
>>> print(foo)
Foo(
    non_tensor=NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
    tensor=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
    nested=None,
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
關鍵字引數:
  • batch_size (torch.Size, optional) – TensorDict 的批次大小。預設為 None

  • device (torch.device, optional) – 將建立 TensorDict 的裝置。預設為 None

  • frozen (bool, optional) – 如果為 True,則生成的類或例項將是不可變的。預設為 False

  • autocast (bool, optional) – 如果為 True,則為生成的類或例項啟用自動型別轉換。預設為 False

  • nocast (bool, optional) – 如果為 True,則停用為生成的類或例項進行的任何型別轉換。預設為 False

  • tensor_only (bool, optional) – 如果為 True,則預期 tensorclass 中的所有項都將是張量例項(張量相容,因為非張量資料會被儘可能轉換為張量)。這可以帶來顯著的速度提升,但會犧牲與非張量資料的靈活互動。預設為 False

  • shadow (bool, optional) – 停用欄位名與 TensorDict 保留屬性的驗證。請謹慎使用,這可能會導致意外後果。預設為 False。

你可以透過兩種方式傳遞布林關鍵字引數(“autocast”“nocast”“frozen”“tensor_only”“shadow”):使用

方括號或關鍵字引數。

示例

>>> class Foo(TensorClass["autocast"]):
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass, autocast=True):  # equivalent
...     integer: int
>>> Foo(integer=torch.ones(())).integer
1
>>> class Foo(TensorClass["nocast"]):
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass["nocast", "frozen"]):  # multiple keywords can be used
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass, nocast=True):  # equivalent
...     integer: int
>>> Foo(integer=1).integer
1
>>> class Foo(TensorClass):
...     integer: int
>>> Foo(integer=1).integer
tensor(1)

警告

TensorClass 本身沒有被裝飾為 tensorclass,但其子類將會。這是因為我們無法預知 `frozen` 引數是否會被設定,如果設定了,它可能與父類衝突(子類不能是 frozen 的,如果父類不是)。

dumps(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

將tensordict儲存到磁碟。

此函式是 `memmap()` 的代理。

from_tensordict(tensordict: TensorDictBase, non_tensordict: Optional[dict] = None, safe: bool = True) Any

用於例項化新張量類物件的張量類包裝器。

引數:
  • tensordict (TensorDictBase) – 張量型別的字典

  • non_tensordict (dict) – 包含非張量和巢狀張量類物件的字典

  • safe (bool) – 如果 tensordict 不是 TensorDictBase 例項,則是否引發錯誤

get(key: NestedKey, *args, **kwargs)

獲取輸入鍵對應的儲存值。

引數:
  • key (str, str 的元組) – 要查詢的鍵。如果是 str 的元組,則等同於鏈式呼叫 getattr。

  • default – 如果在張量類中找不到鍵,則返回預設值。

返回:

儲存在輸入鍵下的值

classmethod load(prefix: str | pathlib.Path, *args, **kwargs) Any

從磁碟載入 tensordict。

此類方法是 `load_memmap()` 的代理。

load_(prefix: str | pathlib.Path, *args, **kwargs)

在當前 tensordict 中從磁碟載入 tensordict。

此類方法是 load_memmap_() 的代理。

classmethod load_memmap(prefix: str | pathlib.Path, device: Optional[device] = None, non_blocking: bool = False, *, out: Optional[TensorDictBase] = None) Any

從磁碟載入記憶體對映的 tensordict。

引數:
  • prefix (str資料夾路徑) – 應從中獲取已儲存 tensordict 的資料夾路徑。

  • device (torch.device等效項, 可選) – 如果提供,資料將非同步轉換為該裝置。支援 `"meta"` 裝置,在這種情況下,資料不會被載入,而是建立一組空的 "meta" 張量。這對於在不實際開啟任何檔案的情況下了解模型大小和結構很有用。

  • non_blocking (bool, 可選) – 如果為 `True`,則在將張量載入到裝置後不會呼叫同步。預設為 `False`。

  • out (TensorDictBase, 可選) – 應將資料寫入其中的可選 tensordict。

示例

>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()

此方法還允許載入巢狀的 tensordicts。

示例

>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0

tensordict 也可以在“meta”裝置上載入,或者作為假張量載入。

示例

>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=False)

嘗試將 state_dict 載入到目標張量類中(原地)。

memmap(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Any

將所有張量寫入記憶體對映的 Tensor 中,並放入新的 tensordict。

引數:
  • prefix (str) – 記憶體對映張量將儲存的目錄字首。目錄樹結構將模仿 tensordict 的結構。

  • copy_existing (bool) – 如果為 False(預設值),並且 tensordict 中某項已是儲存在磁碟上的張量且關聯了檔案,但未按 prefix 儲存到正確位置,則會引發異常。如果為 True,則任何現有張量都將被複制到新位置。

關鍵字引數:
  • num_threads (int, 可選) – 用於寫入 memmap 張量的執行緒數。預設為 0

  • return_early (bool, 可選) – 如果設定為 Truenum_threads>0,則該方法將返回 tensordict 的一個 future。

  • share_non_tensor (bool, 可選) – 如果設定為 True,則非張量資料將在程序之間共享,並且在單個節點內的任何工作者上進行的寫入操作(例如就地更新或設定)將更新所有其他工作者上的值。如果非張量葉子節點數量很多(例如,共享大量非張量資料),這可能會導致 OOM 或類似錯誤。預設為 False

  • existsok (bool, optional) – 如果為 False,則如果同一路徑下已存在張量,將引發異常。預設為 True

然後,Tensordict 被鎖定,這意味著任何非就地寫入操作(例如重新命名、設定或刪除條目)都將引發異常。一旦 tensordict 被解鎖,記憶體對映屬性將變為 False,因為不能保證跨程序身份。

返回:

返回一個新的 tensordict,其中張量儲存在磁碟上(如果 return_early=False),否則返回一個 TensorDictFuture 例項。

注意

以這種方式序列化對於深度巢狀的 tensordicts 來說可能很慢,因此不建議在訓練迴圈中呼叫此方法。

memmap_(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True) Any

將所有張量原地寫入相應的記憶體對映張量。

引數:
  • prefix (str) – 記憶體對映張量將儲存的目錄字首。目錄樹結構將模仿 tensordict 的結構。

  • copy_existing (bool) – 如果為 False(預設值),並且 tensordict 中某項已是儲存在磁碟上的張量且關聯了檔案,但未按 prefix 儲存到正確位置,則會引發異常。如果為 True,則任何現有張量都將被複制到新位置。

關鍵字引數:
  • num_threads (int, 可選) – 用於寫入 memmap 張量的執行緒數。預設為 0

  • return_early (bool, optional) – 如果為 Truenum_threads>0,則方法將返回一個 tensordict 的 future。生成的 tensordict 可以使用 future.result() 進行查詢。

  • share_non_tensor (bool, 可選) – 如果設定為 True,則非張量資料將在程序之間共享,並且在單個節點內的任何工作者上進行的寫入操作(例如就地更新或設定)將更新所有其他工作者上的值。如果非張量葉子節點數量很多(例如,共享大量非張量資料),這可能會導致 OOM 或類似錯誤。預設為 False

  • existsok (bool, optional) – 如果為 False,則如果同一路徑下已存在張量,將引發異常。預設為 True

然後,Tensordict 被鎖定,這意味著任何非就地寫入操作(例如重新命名、設定或刪除條目)都將引發異常。一旦 tensordict 被解鎖,記憶體對映屬性將變為 False,因為不能保證跨程序身份。

返回:

如果 return_early=False,則返回 self,否則返回 TensorDictFuture 例項。

注意

以這種方式序列化對於深度巢狀的 tensordicts 來說可能很慢,因此不建議在訓練迴圈中呼叫此方法。

memmap_like(prefix: Optional[str] = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

建立一個無內容的記憶體對映 tensordict,其形狀與原始 tensordict 相同。

引數:
  • prefix (str) – 記憶體對映張量將儲存的目錄字首。目錄樹結構將模仿 tensordict 的結構。

  • copy_existing (bool) – 如果為 False(預設值),並且 tensordict 中某項已是儲存在磁碟上的張量且關聯了檔案,但未按 prefix 儲存到正確位置,則會引發異常。如果為 True,則任何現有張量都將被複制到新位置。

關鍵字引數:
  • num_threads (int, 可選) – 用於寫入 memmap 張量的執行緒數。預設為 0

  • return_early (bool, 可選) – 如果設定為 Truenum_threads>0,則該方法將返回 tensordict 的一個 future。

  • share_non_tensor (bool, 可選) – 如果設定為 True,則非張量資料將在程序之間共享,並且在單個節點內的任何工作者上進行的寫入操作(例如就地更新或設定)將更新所有其他工作者上的值。如果非張量葉子節點數量很多(例如,共享大量非張量資料),這可能會導致 OOM 或類似錯誤。預設為 False

  • existsok (bool, optional) – 如果為 False,則如果同一路徑下已存在張量,將引發異常。預設為 True

然後,Tensordict 被鎖定,這意味著任何非就地寫入操作(例如重新命名、設定或刪除條目)都將引發異常。一旦 tensordict 被解鎖,記憶體對映屬性將變為 False,因為不能保證跨程序身份。

返回:

如果 return_early=False,則建立一個新的 TensorDict 例項,其中資料儲存為記憶體對映張量;否則,建立一個 TensorDictFuture 例項。

注意

這是將一組大型緩衝區寫入磁碟的推薦方法,因為 `memmap_()` 將會複製資訊,這對於大型內容來說可能會很慢。

示例

>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
memmap_refresh_()

如果記憶體對映的 tensordict 具有 saved_path,則重新整理其內容。

如果沒有任何路徑與之關聯,此方法將引發異常。

save(prefix: Optional[str] = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False) Any

將tensordict儲存到磁碟。

此函式是 `memmap()` 的代理。

set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)

設定一個新的鍵值對。

引數:
  • key (str, tuple of str) – 要設定的鍵名。如果是字串元組,則等同於鏈式呼叫 getattr,然後最後呼叫 setattr。

  • value (Any) – 要儲存在張量類中的值

  • inplace (bool, optional) – 如果為 True,則 set 將嘗試就地更新值。如果為 False 或鍵不存在,則值將簡單地寫入其目標位置。

返回:

self

state_dict(destination=None, prefix='', keep_vars=False, flatten=False) dict[str, Any]

返回一個 state_dict 字典,可用於儲存和載入張量類的資料。

to_tensordict(*, retain_none: Optional[bool] = None) TensorDict

將張量類轉換為常規 TensorDict。

複製所有條目。記憶體對映和共享記憶體張量將被轉換為常規張量。

引數:

retain_none (bool) – 如果 True,則 None 值將被寫入 tensordict。否則,它們將被丟棄。預設值:True

返回:

包含與張量類相同值的新的 TensorDict 物件。

unbind(dim: int)

返回沿指定維度解綁的索引張量類例項的元組。

結果張量類例項將共享初始張量類例項的儲存。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源