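TensorDictSequential
- class tensordict.nn.TensorDictSequential(*args, **kwargs)
A sequence of TensorDictModules.
Similarly to nn.Sequential, which passes a tensor through a chain of mappings that each read and write a single tensor, this module reads from and writes to a tensordict by querying each of its input modules in turn. When a TensorDictSequential instance is called with a functional module, the parameter lists (and buffers) are expected to be concatenated into a single list.
- Parameters: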
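modules (OrderedDict[str, Callable[[TensorDictBase], TensorDictBase]] | List[Callable[[TensorDictBase], TensorDictBase]]) – An ordered sequence of callables that each take a TensorDictBase as input and return a TensorDictBase. These can be instances of TensorDictModuleBase or any other function matching this signature. Note that if a callable that is not a TensorDictModuleBase is used, its input and output keys are not tracked and therefore do not affect the in_keys and out_keys attributes of the TensorDictSequential. A regular dict input is converted to an OrderedDict if necessary.
- Keyword Arguments: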
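partial_tolerant (bool, optional) – If True, the input tensordict may be missing some of the input keys. In that case, only the modules that can be executed given the keys that are present are run. In addition, if the input tensordict is a lazy stack of tensordicts, partial_tolerant is True and the stack does not have the required keys, TensorDictSequential scans through the sub-tensordicts to find those that have the required keys, if any. Defaults to False.
selected_out_keys (iterable of NestedKeys, optional) – The list of out-keys to select. If not provided, all out_keys are written.
inplace (bool or str, optional) – If True, the input tensordict is modified in place. If False, a new, empty TensorDict instance is created. If "empty", input.empty() is used instead (i.e., the output preserves the type, device and batch size). Defaults to None (relies on the sub-modules).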
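The effect of partial_tolerant can be pictured with a plain-Python sketch. This is not the library's implementation: the run_chain helper and the (in_keys, out_keys, fn) triples are hypothetical stand-ins for a TensorDictSequential and its modules, and plain dicts stand in for tensordicts. It only illustrates the rule described above, under the assumption that a module whose input keys are missing is skipped rather than raising.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a module: (in_keys, out_keys, function).
Step = Tuple[List[str], List[str], Callable[..., tuple]]


def run_chain(
    steps: List[Step],
    data: Dict[str, float],
    partial_tolerant: bool = False,
) -> Dict[str, float]:
    """Run each step in order on a copy of ``data`` (illustration only)."""
    out = dict(data)
    for in_keys, out_keys, fn in steps:
        missing = [key for key in in_keys if key not in out]
        if missing:
            if partial_tolerant:
                # Skip steps that cannot run with the keys that are present.
                continue
            raise KeyError(f"missing input keys: {missing}")
        results = fn(*(out[key] for key in in_keys))
        out.update(zip(out_keys, results))
    return out


steps: List[Step] = [
    (["a"], ["b"], lambda a: (a + 1,)),
    (["c"], ["d"], lambda c: (c * 2,)),  # "c" is absent from the input below
]

tolerant = run_chain(steps, {"a": 1.0}, partial_tolerant=True)
print(tolerant)
```

With partial_tolerant left at False, the same call raises a KeyError in this sketch, because the second step cannot find "c".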
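Note
A TensorDictSequential instance may have a long list of output keys, and one may wish to remove some of them after execution for clarity or to save memory. In that case, the select_out_keys() method can be called after instantiation, or selected_out_keys can be passed to the constructor.
Examples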
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> torch.manual_seed(0)
>>> module = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["x+1"]),
...     TensorDictModule(nn.Linear(3, 4), in_keys=["x+1"], out_keys=["w*(x+1)+b"]),
... )
>>> # with tensordict input
>>> print(module(TensorDict({"x": torch.zeros(3)}, [])))
TensorDict(
    fields={
        w*(x+1)+b: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
        x+1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        x: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # with tensor input: returns all the output keys in the order of the modules, i.e. "x+1" and "w*(x+1)+b"
>>> module(x=torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
>>> module(torch.zeros(3))
(tensor([1., 1., 1.]), tensor([-0.7214, -0.8748, 0.1571, -0.1138], grad_fn=<AddBackward0>))
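The way a sequence exposes its own in_keys and out_keys can be sketched in plain Python. The chain_keys helper below is hypothetical and works on bare (in_keys, out_keys) pairs rather than on real modules. It assumes that a key written by an earlier module is not counted as an input of the whole chain, which is consistent with the example above, where only "x" has to be supplied.

```python
from typing import List, Sequence, Tuple

# (in_keys, out_keys) of one tracked module; untracked callables have no entry.
KeySpec = Tuple[Sequence[str], Sequence[str]]


def chain_keys(specs: Sequence[KeySpec]) -> Tuple[List[str], List[str]]:
    """Aggregate the keys of a chain of modules (illustration only)."""
    in_keys: List[str] = []
    out_keys: List[str] = []
    for module_in, module_out in specs:
        for key in module_in:
            # A key produced earlier in the chain is not an input of the chain.
            if key not in out_keys and key not in in_keys:
                in_keys.append(key)
        for key in module_out:
            if key not in out_keys:
                out_keys.append(key)
    return in_keys, out_keys


seq_in, seq_out = chain_keys([(["x"], ["x+1"]), (["x+1"], ["w*(x+1)+b"])])
print(seq_in, seq_out)
```

The output keys come out in module order, matching the order in which the tensors are returned when the module is called with tensor inputs.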
TensorDictSequential supports functional, modular and vmap coding.
Examples
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import (
...     ProbabilisticTensorDictModule,
...     ProbabilisticTensorDictSequential,
...     TensorDictModule,
...     TensorDictSequential,
... )
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from torch.distributions import Normal
>>> td = TensorDict({"input": torch.randn(3, 4)}, [3,])
>>> net1 = torch.nn.Linear(4, 8)
>>> module1 = TensorDictModule(net1, in_keys=["input"], out_keys=["params"])
>>> normal_params = TensorDictModule(
...     NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"]
... )
>>> td_module1 = ProbabilisticTensorDictSequential(
...     module1,
...     normal_params,
...     ProbabilisticTensorDictModule(
...         in_keys=["loc", "scale"],
...         out_keys=["hidden"],
...         distribution_class=Normal,
...         return_log_prob=True,
...     )
... )
>>> module2 = torch.nn.Linear(4, 8)
>>> td_module2 = TensorDictModule(
...     module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
...     _ = td_module(td)
>>> print(td)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
- In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4)
>>> def func(td, params):
...     with params.to_module(td_module):
...         return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
    fields={
        hidden: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        output: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, **kwargs: Any) → TensorDictBase
When the tensordict parameter is not set, kwargs are used to create an instance of TensorDict.
- reset_out_keys()
Resets the out_keys attribute to its original value.
Returns: the same module, with its out_keys values reset.
Examples
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
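- select_out_keys(*selected_out_keys) → TensorDictSequential
Selects the keys that will be found in the output tensordict.
This is useful whenever one wants to discard intermediate keys in a complicated graph, or when the presence of these keys may trigger unexpected behaviors.
The original out_keys can still be accessed via module.out_keys_source.
- Parameters:
*out_keys (a sequence of strings or tuples of strings) – The out_keys that should be found in the output tensordict.
Returns: the same module, modified in place, with updated out_keys.
The simplest usage is with TensorDictModule:
Examples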
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This feature also works with dispatched arguments:
Examples
>>> mod(torch.zeros(()), torch.ones(()))
tensor(3.)
This change happens in place (i.e. the same module is returned, with an updated list of out_keys). It can be reverted with the TensorDictModuleBase.reset_out_keys() method.
Examples
>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
This also works with other classes, such as Sequential:
Examples
>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
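- select_subsequence(in_keys: Optional[Iterable[NestedKey]] = None, out_keys: Optional[Iterable[NestedKey]] = None) → TensorDictSequential
Returns a new TensorDictSequential containing only the modules that are necessary to compute the given output keys from the given input keys.
- Parameters:
in_keys – The input keys of the subsequence to select. All keys absent from in_keys are considered irrelevant, and modules that take only those keys as inputs are discarded. The resulting sequential module follows the pattern "all the modules whose output is affected by a different value of any of the in_keys". If none is provided, the module's in_keys are assumed.
out_keys – The output keys of the subsequence to select. Only the modules that are necessary to obtain the out_keys are kept in the resulting sequence. The resulting sequential module follows the pattern "all the modules that condition the value of the out_keys entries". If none is provided, the module's out_keys are assumed.
- Returns:
A new TensorDictSequential containing only the modules that are necessary given the input and output keys.
Examples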
>>> from tensordict.nn import TensorDictSequential as Seq, TensorDictModule as Mod
>>> idn = lambda x: x
>>> module = Seq(
...     Mod(idn, in_keys=["a"], out_keys=["b"]),
...     Mod(idn, in_keys=["b"], out_keys=["c"]),
...     Mod(idn, in_keys=["c"], out_keys=["d"]),
...     Mod(idn, in_keys=["a"], out_keys=["e"]),
... )
>>> # select all modules whose output depends on "a"
>>> module.select_subsequence(in_keys=["a"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
      (2): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
      (3): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c', 'd', 'e'])
>>> # select all modules whose output depends on "c"
>>> module.select_subsequence(in_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['c'],
          out_keys=['d'])
    ),
    device=cpu,
    in_keys=['c'],
    out_keys=['d'])
>>> # select all modules that affect the value of "c"
>>> module.select_subsequence(out_keys=["c"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['b'])
      (1): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['b'],
          out_keys=['c'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['b', 'c'])
>>> # select all modules that affect the value of "e"
>>> module.select_subsequence(out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=<function <lambda> at 0x126ed1ca0>,
          device=cpu,
          in_keys=['a'],
          out_keys=['e'])
    ),
    device=cpu,
    in_keys=['a'],
    out_keys=['e'])
This method propagates to nested sequences:
>>> module = Seq(
...     Seq(
...         Mod(idn, in_keys=["a"], out_keys=["b"]),
...         Mod(idn, in_keys=["b"], out_keys=["c"]),
...     ),
...     Seq(
...         Mod(idn, in_keys=["b"], out_keys=["d"]),
...         Mod(idn, in_keys=["d"], out_keys=["e"]),
...     ),
... )
>>> # select submodules whose output will be affected by a change in "b" or "d" AND whose output is "e"
>>> module.select_subsequence(in_keys=["b", "d"], out_keys=["e"])
TensorDictSequential(
    module=ModuleList(
      (0): TensorDictSequential(
          module=ModuleList(
            (0): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['b'],
                out_keys=['d'])
            (1): TensorDictModule(
                module=<function <lambda> at 0x129efae50>,
                device=cpu,
                in_keys=['d'],
                out_keys=['e'])
          ),
          device=cpu,
          in_keys=['b'],
          out_keys=['d', 'e'])
    ),
    device=cpu,
    in_keys=['b'],
    out_keys=['d', 'e'])
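The selection rule of select_subsequence can be sketched in plain Python as a forward pass followed by a backward pass. This is a simplified illustration, not the library's implementation: select_subsequence_sketch is a hypothetical helper, modules are reduced to (in_keys, out_keys) pairs, and nested sequences are assumed to be flattened beforehand.

```python
from typing import List, Optional, Sequence, Tuple

KeySpec = Tuple[List[str], List[str]]  # (in_keys, out_keys) of one module


def select_subsequence_sketch(
    specs: Sequence[KeySpec],
    in_keys: Optional[Sequence[str]] = None,
    out_keys: Optional[Sequence[str]] = None,
) -> List[KeySpec]:
    """Return the specs needed to go from in_keys to out_keys (illustration only)."""
    if in_keys is None:
        # Default: keys read by the chain that no earlier module writes.
        written: List[str] = []
        available: List[str] = []
        for module_in, module_out in specs:
            available += [k for k in module_in if k not in written and k not in available]
            written += module_out
    else:
        available = list(in_keys)
    if out_keys is None:
        # Default: every key written by the chain.
        wanted = [k for _, module_out in specs for k in module_out]
    else:
        wanted = list(out_keys)

    # Forward pass: keep modules whose inputs are all reachable from in_keys.
    kept = []
    for index, (module_in, module_out) in enumerate(specs):
        if all(k in available for k in module_in):
            available.extend(module_out)
            kept.append(index)

    # Backward pass: among those, keep modules that contribute to out_keys.
    selected = []
    for index in reversed(kept):
        module_in, module_out = specs[index]
        if any(k in wanted for k in module_out):
            wanted.extend(module_in)
            selected.append(index)
    return [specs[i] for i in sorted(selected)]


chain = [(["a"], ["b"]), (["b"], ["c"]), (["c"], ["d"]), (["a"], ["e"])]
print(select_subsequence_sketch(chain, in_keys=["c"]))
print(select_subsequence_sketch(chain, out_keys=["c"]))
```

On the four-module chain used in the examples above, the sketch keeps the same modules as the corresponding select_subsequence calls.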
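The inplace argument allows fine-grained control over the output type, for instance allowing the result of the computational graph to be written into the input object without keeping track of the intermediate tensors.
Examples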
>>> import torch
>>> from tensordict import TensorClass
>>> from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
>>>
>>> class MyClass(TensorClass):
...     input: torch.Tensor
...     output: torch.Tensor | None = None
>>>
>>> obj = MyClass(torch.randn(2, 3), batch_size=(2,))
>>>
>>> model = Seq(
...     Mod(
...         lambda x: (x + 1, x - 1),
...         in_keys=["input"],
...         out_keys=[("intermediate", "0"), ("intermediate", "1")],
...         inplace=False
...     ),
...     Mod(
...         lambda y0, y1: y0 * y1,
...         in_keys=[("intermediate", "0"), ("intermediate", "1")],
...         out_keys=["output"],
...         inplace=False
...     ),
...     inplace=True,
... )
>>> print(model(obj))
MyClass(
    input=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
    output=Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
    output=None,
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)