快捷方式

VmapModule

class torchrl.modules.VmapModule(*args, **kwargs)[原始碼]

一個 TensorDictModule 包裝器,用於對輸入進行 vmap 操作。

它旨在與接受的批次維度比提供的批次維度少一個的模組一起使用。透過使用此包裝器,可以隱藏一個批次維度並滿足被包裝模組的要求。

引數:
  • module (TensorDictModuleBase) – 要進行 vmap 操作的模組。

  • vmap_dim (int, optional) – vmap 的輸入和輸出維度。如果未提供,則假定 tensordict 的最後一個維度。

注意

由於 vmap 需要控制輸入的批次大小,因此此模組不支援分派的引數

示例

>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"])
>>> sample_in = torch.ones((10,3,2))
>>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10])
>>> lam(sample_in)
>>> vm = VmapModule(lam, 0)
>>> vm(sample_in_td)
>>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
forward(tensordict)[原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源