快捷方式

WrapModule

class tensordict.nn.WrapModule(*args, **kwargs)

一個圍繞任何處理 TensorDict 例項的可呼叫物件的包裝器。

當構建 TensorDictSequential 堆疊且某個轉換需要整個 TensorDict 例項可見時,此包裝器很有用。

引數:

func (Callable[[TensorDictBase], TensorDictBase]) – 一個可呼叫函式,它接受一個 TensorDictBase 例項並返回一個轉換後的 TensorDictBase 例項。

關鍵字引數:
  • inplace (bool, optional) – 如果為 True,則輸入 TensorDict 將被原地修改。否則,將返回一個新的 TensorDict(如果函式不原地修改它並返回它)。預設為 False

  • in_keys (list of NestedKey, optional) – 如果提供,則指示模組讀取哪些條目。這不會被檢查,僅用於告知 TensorDictSequential 被包裝模組的輸入鍵。預設為 []

  • out_keys (list of NestedKey, optional) – 如果提供,則指示模組寫入哪些條目。這不會被檢查,僅用於告知 TensorDictSequential 被包裝模組的輸出鍵。預設為 []

示例

>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod, WrapModule
>>> seq = Seq(
...     Mod(lambda x: x * 2, in_keys=["x"], out_keys=["y"]),
...     WrapModule(lambda td: td.reshape(-1)),
... )
>>> td = TensorDict(x=torch.ones(3, 4, 5), batch_size=[3, 4])
>>> td = Seq(td)
>>> assert td.shape == (12,)
>>> assert (td["y"] == 2).all()
>>> assert td["y"].shape == (12, 5)
forward(data: TensorDictBase) TensorDictBase

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源