VmapModule¶
- class torchrl.modules.VmapModule(*args, **kwargs)[原始碼]¶
一個 TensorDictModule 包裝器,用於對輸入進行 vmap 操作。
它旨在與接受的批次維度比提供的批次維度少一個的模組一起使用。透過使用此包裝器,可以隱藏一個批次維度並滿足被包裝模組的要求。
- 引數:
module (TensorDictModuleBase) – 要進行 vmap 操作的模組。
vmap_dim (int, optional) – vmap 的輸入和輸出維度。如果未提供,則假定 tensordict 的最後一個維度。
注意
由於 vmap 需要控制輸入的批次大小,因此此模組不支援分派的引數
示例
>>> lam = TensorDictModule(lambda x: x[0], in_keys=["x"], out_keys=["y"]) >>> sample_in = torch.ones((10,3,2)) >>> sample_in_td = TensorDict({"x":sample_in}, batch_size=[10]) >>> lam(sample_in) >>> vm = VmapModule(lam, 0) >>> vm(sample_in_td) >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()