tensordict.nn.dispatch¶
- class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)¶
允許一個期望 TensorDict 的函式使用關鍵字引數進行呼叫。
dispatch()必須用在具有in_keys(或由source關鍵字引數指定的其他鍵源)和out_keys(或另一個dest鍵列表)屬性的模組內部,這些屬性指示了要從 tensordict 中讀取和寫入的鍵。被包裝的函式還應該有一個tensordict作為首個引數。生成的函式將返回一個單一的張量(如果 out_keys 中只有一個元素),否則將返回一個根據模組的
out_keys排序的元組。dispatch()可以作為方法使用,也可以作為類使用,以傳遞額外的引數。- 引數:
separator (str, optional) – 用於組合子鍵的、當
in_keys是字串元組時的分隔符。預設為"_"。source (str or list of keys, optional) – 如果提供了一個字串,它將指向包含要使用的輸入鍵列表的模組屬性。如果提供了一個列表,它將包含用作模組輸入的鍵。預設為
"in_keys",這是TensorDictModule輸入鍵列表的屬性名稱。dest (str or list of keys, optional) – 如果提供了一個字串,它將指向包含要使用的輸出鍵列表的模組屬性。如果提供了一個列表,它將包含用作模組輸出的鍵。預設為
"out_keys",這是TensorDictModule輸出鍵列表的屬性名稱。auto_batch_size (bool, optional) – 如果為
True,則輸入 tensordict 的批次大小將自動確定為所有輸入張量之間公共維度的最大數量。預設為True。
示例
>>> class MyModule(nn.Module): ... in_keys = ["a"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # equivalently >>> class MyModule(nn.Module): ... keys_in = ["a"] ... keys_out = ["b"] ... ... @dispatch(source="keys_in", dest="keys_out") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all() >>> # or this >>> class MyModule(nn.Module): ... @dispatch(source=["a"], dest=["b"]) ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a'] + 1 ... return tensordict ... >>> module = MyModule() >>> b = module(a=torch.zeros(1, 2)) >>> assert (b == 1).all()
dispatch_kwargs()也可以與巢狀鍵一起使用,使用預設的"_"分隔符。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(a_c=torch.zeros(1, 2)) >>> assert (b == 1).all()
如果想要其他分隔符,可以在建構函式中使用
separator引數來指定。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c")] ... out_keys = ["b"] ... ... @dispatch(separator="sep") ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + 1 ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(asepc=torch.zeros(1, 2)) >>> assert (b == 1).all()
由於輸入鍵是一個排序後的字串序列,
dispatch()也可以與無名引數一起使用,此時順序必須與輸入鍵的順序匹配。注意
如果第一個引數是
TensorDictBase例項,則假定 __未使用__ dispatch,並且該 tensordict 包含執行透過模組所需的所有資訊。換句話說,不能使用模組輸入的第一個鍵指向一個 tensordict 例項來分解一個 tensordict。一般來說,更傾向於使用dispatch()處理僅包含葉子節點的 tensordict。示例
>>> class MyModuleNest(nn.Module): ... in_keys = [("a", "c"), "d"] ... out_keys = ["b"] ... ... @dispatch ... def forward(self, tensordict): ... tensordict['b'] = tensordict['a', 'c'] + tensordict["d"] ... return tensordict ... >>> module = MyModuleNest() >>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> b, = module(torch.zeros(1, 2), torch.ones(1, 2)) # works >>> assert (b == 1).all() >>> try: ... b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) # fails ... except: ... print("oopsy!") ...