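TensorDictModuleBase
- class tensordict.nn.TensorDictModuleBase(*args, **kwargs)
Base class for TensorDict modules.
TensorDictModule subclasses are characterized by their in_keys and out_keys key lists, which indicate which input entries are read and which output entries are expected to be written. The input/output signature of the forward method should always follow the convention:
>>> tensordict_out = module.forward(tensordict_in)
Unlike TensorDictModule, TensorDictModuleBase is typically used through subclassing: any Python function can be wrapped in a TensorDictModuleBase subclass, provided the subclass's forward method reads and writes tensordict (or related type) instances. in_keys and out_keys should be specified properly. For example, out_keys can be reduced dynamically with select_out_keys().
Examples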
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModuleBase
>>> class Mod(TensorDictModuleBase):
...     in_keys = ["a"]  # can also be specified during __init__
...     out_keys = ["b", "c"]
...     def forward(self, tensordict):
...         b = tensordict["a"].clone()
...         c = b + 1
...         return tensordict.replace({"b": b, "c": c})
>>> mod = Mod()
>>> td = mod(TensorDict(a=0))
>>> td["b"]
tensor(0)
>>> td["c"]
tensor(1)
>>> mod.select_out_keys("c")
>>> td = mod(TensorDict(a=0))
>>> td["c"]
tensor(1)
>>> assert "b" not in td
- static is_tdmodule_compatible(module)
Checks whether a module is compatible with the TensorDictModule API.
- reset_out_keys()
Resets the out_keys attribute to its original value.
Returns: the same module, with its out_keys value reset.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
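- reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) → Optional[TensorDictBase]
Recursively resets the parameters of the module and its children.
- Parameters:
parameters (TensorDict of parameters, optional) – If set to None, the module resets itself using self.parameters(). Otherwise, the parameters in the tensordict are reset in place. This is useful for functional modules, where the parameters are not stored in the module itself.
- Returns:
A tensordict of the new parameters, only if parameters was not None.
Examples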
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters_recursive()
>>> (old_param == net[0].weight).any()
tensor(False)
This method also supports functional parameter sampling:
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters_recursive(params)
>>> (old_params == params).any()
False
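- select_out_keys(*out_keys) → TensorDictModuleBase
Selects the keys that will be found in the output tensordict.
This is useful when one wants to discard intermediate keys in a complex graph, or when the presence of these keys may trigger unexpected behavior.
The original out_keys can still be accessed via module.out_keys_source.
- Parameters:
*out_keys (a sequence of strings or tuples of strings) – The out_keys that should be found in the output tensordict.
Returns: the same module, modified in place, with updated out_keys.
The simplest usage is with TensorDictModule:
Examples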
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This feature also works with dispatched arguments:
Examples
>>> mod(torch.zeros(()), torch.ones(()))
tensor(3.)
This change happens in place (i.e. the same module is returned, with an updated list of out_keys). It can be reverted with the TensorDictModuleBase.reset_out_keys() method.
Examples
>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This also works with other classes, such as Sequential:
Examples
>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)