快捷方式

TensorDictModule

class tensordict.nn.TensorDictModule(*args, **kwargs)

TensorDictModule 是一個 Python 包裝器,圍繞一個 nn.Module,用於讀寫 TensorDict。

引數:
  • module (Callable[[Any], Any]) – 一個可呼叫物件,通常是一個 torch.nn.Module,用於將輸入對映到輸出引數空間。它的 forward 方法可以返回單個張量、張量元組,甚至是一個字典。在後一種情況下,TensorDictModule 的輸出鍵將用於填充輸出 tensordict(即 out_keys 中存在的鍵應該存在於 module forward 方法返回的字典中)。

  • in_keys (iterable of NestedKeys, Dict[NestedStr, str]) – 從輸入 tensordict 讀取並傳遞給模組的鍵。如果包含多個元素,則按 in_keys 可迭代物件的順序傳遞值。如果 in_keys 是一個字典,其鍵必須對應於 tensordict 中要讀取的鍵,其值必須與函式簽名中的關鍵字引數名稱匹配。如果 out_to_in_mapTrue,則對映會被反轉,以便鍵對應於函式簽名中的關鍵字引數。

  • out_keys (iterable of str) – 要寫入輸入 tensordict 的鍵。out_keys 的長度必須與嵌入模組返回的張量數量匹配。使用“_”作為鍵可以避免將張量寫入輸出。

關鍵字引數:
  • out_to_in_map (bool, optional) – 如果為 True(預設),則 in_keys 的讀取方式就像鍵是 forward() 方法的引數鍵,值是輸入 TensorDict 中的鍵。如果為 False,則鍵被視為輸入鍵,值被視為方法的引數鍵。

  • inplace (bool or string, optional) –

    如果為 True(預設),則模組的輸出將被寫入傳遞給 forward() 方法的 tensordict。如果為 False,則會建立一個具有空批次大小和無裝置的新的 TensorDict。如果為 "empty",則將使用 empty() 來建立輸出 tensordict。

    注意

    如果 inplace=False 且傳遞給模組的 tensordict 是 TensorDict 以外的 TensorDictBase 子類,則輸出仍將是 TensorDict 例項。其批次大小將為空,並且沒有裝置。將其設定為 "empty" 以獲得相同的 TensorDictBase 子型別、相同的批次大小和裝置。在執行時使用 tensordict_out(見下文)可以更精細地控制輸出。

    注意

    如果 inplace=False 並且 forward() 方法中傳遞了 tensordict_out,則 tensordict_out 將優先。這是獲得 tensordict_out 的方式,傳遞給模組的 tensordict 是 TensorDict 以外的 TensorDictBase 子類,輸出仍將是 TensorDict 例項。

  • method (str, optional) – 要在模組中呼叫的方法(如果存在)。預設為 __call__

  • method_kwargs (Dict[str, Any], optional) – 要傳遞給被呼叫模組方法的附加關鍵字引數。

  • strict (bool, optional) – 如果為 True,則模組將在輸入 tensordict 中缺少任何輸入時引發異常。否則,將使用 None 值作為佔位符。預設為 False

  • get_kwargs (dict[str, Any], optional) – 要傳遞給 get() 方法的附加關鍵字引數。這在處理不規則張量時尤其有用(參見 get())。預設為 {}

將神經網路嵌入 TensorDictModule 只需指定輸入和輸出鍵。TensorDictModule 支援函式式和常規 nn.Module 物件。在函式式情況下,必須指定 ‘params’(以及 ‘buffers’)關鍵字引數。

示例

>>> from tensordict import TensorDict
>>> # one can wrap regular nn.Module
>>> module = TensorDictModule(nn.Transformer(128), in_keys=["input", "tgt"], out_keys=["out"])
>>> input = torch.ones(2, 3, 128)
>>> tgt = torch.zeros(2, 3, 128)
>>> data = TensorDict({"input": input, "tgt": tgt}, batch_size=[2, 3])
>>> data = module(data)
>>> print(data)
TensorDict(
    fields={
        input: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False),
        tgt: Tensor(shape=torch.Size([2, 3, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=None,
    is_shared=False)

我們也可以直接傳遞張量。

示例

>>> out = module(input, tgt)
>>> assert out.shape == input.shape
>>> # we can also wrap regular functions
>>> module = TensorDictModule(lambda x: (x-1, x+1), in_keys=[("input", "x")], out_keys=[("output", "x-1"), ("output", "x+1")])
>>> module(TensorDict({("input", "x"): torch.zeros(())}, batch_size=[]))
TensorDict(
    fields={
        input: TensorDict(
            fields={
                x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False),
        output: TensorDict(
            fields={
                x+1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
                x-1: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

我們可以使用 TensorDictModule 來填充 tensordict。

示例

>>> module = TensorDictModule(lambda: torch.randn(3), in_keys=[], out_keys=["x"])
>>> print(module(TensorDict({}, batch_size=[])))
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

另一個功能是傳遞一個字典作為輸入鍵,以控制值到特定關鍵字引數的分派。

示例

>>> module = TensorDictModule(lambda x, *, y: x+y,
...     in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)

如果將 out_to_in_map 設定為 True,則 in_keys 對映會被反轉。這樣,就可以將相同的輸入鍵用於不同的關鍵字引數。

示例

>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
...     in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True
...     )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)

我們可以指定模組內要呼叫的方法。與使用 lambda 函式或類似函式包裝模組方法相比,它的優點是模組屬性(params、buffers、submodules)將被暴露。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> from torch import nn
>>> import torch
>>>
>>> class MyNet(nn.Module):
...     def my_func(self, tensor: torch.Tensor, *, an_integer: int):
...         return tensor + an_integer
...
>>> s = Seq(
...     {
...         "a": lambda td: td+1,
...         "b": lambda td: td * 2,
...         "c": Mod(MyNet(), in_keys=["a"], out_keys=["b"], method="my_func", method_kwargs={"an_integer": 2}),
...     }
... )
>>> td = s(TensorDict(a=0))
>>> print(td)
>>>
>>> assert td["b"] == 4

對 tensordict 模組進行函式式呼叫很容易。

示例

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
...     td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
在有狀態的情況下。
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
...    module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> td_stateful = td_module(td.clone())
>>> print(td_stateful)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
forward(tensordict: TensorDictBase = None, args=None, *, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) TensorDictBase

當 tensordict 引數未設定時,kwargs 用於建立 TensorDict 的例項。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源