快捷方式

tensordict.nn.dispatch

class tensordict.nn.dispatch(separator='_', source='in_keys', dest='out_keys', auto_batch_size: bool = True)

允許一個期望 TensorDict 的函式使用關鍵字引數進行呼叫。

dispatch() 必須用在具有 in_keys(或由 source 關鍵字引數指定的其他鍵源)和 out_keys(或另一個 dest 鍵列表)屬性的模組內部,這些屬性指示了要從 tensordict 中讀取和寫入的鍵。被包裝的函式還應該有一個 tensordict 作為首個引數。

生成的函式將返回一個單一的張量(如果 out_keys 中只有一個元素),否則將返回一個根據模組的 out_keys 排序的元組。

dispatch() 可以作為方法使用,也可以作為類使用,以傳遞額外的引數。

引數:
  • separator (str, optional) – 用於組合子鍵的、當 in_keys 是字串元組時的分隔符。預設為 "_"

  • source (str or list of keys, optional) – 如果提供了一個字串,它將指向包含要使用的輸入鍵列表的模組屬性。如果提供了一個列表,它將包含用作模組輸入的鍵。預設為 "in_keys",這是 TensorDictModule 輸入鍵列表的屬性名稱。

  • dest (str or list of keys, optional) – 如果提供了一個字串,它將指向包含要使用的輸出鍵列表的模組屬性。如果提供了一個列表,它將包含用作模組輸出的鍵。預設為 "out_keys",這是 TensorDictModule 輸出鍵列表的屬性名稱。

  • auto_batch_size (bool, optional) – 如果為 True,則輸入 tensordict 的批次大小將自動確定為所有輸入張量之間公共維度的最大數量。預設為 True

示例

>>> class MyModule(nn.Module):
...     in_keys = ["a"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # equivalently
>>> class MyModule(nn.Module):
...     keys_in = ["a"]
...     keys_out = ["b"]
...
...     @dispatch(source="keys_in", dest="keys_out")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()
>>> # or this
>>> class MyModule(nn.Module):
...     @dispatch(source=["a"], dest=["b"])
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a'] + 1
...         return tensordict
...
>>> module = MyModule()
>>> b = module(a=torch.zeros(1, 2))
>>> assert (b == 1).all()

dispatch_kwargs() 也可以與巢狀鍵一起使用,使用預設的 "_" 分隔符。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(a_c=torch.zeros(1, 2))
>>> assert (b == 1).all()

如果想要其他分隔符,可以在建構函式中使用 separator 引數來指定。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c")]
...     out_keys = ["b"]
...
...     @dispatch(separator="sep")
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + 1
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(asepc=torch.zeros(1, 2))
>>> assert (b == 1).all()

由於輸入鍵是一個排序後的字串序列,dispatch() 也可以與無名引數一起使用,此時順序必須與輸入鍵的順序匹配。

注意

如果第一個引數是 TensorDictBase 例項,則假定 __未使用__ dispatch,並且該 tensordict 包含執行透過模組所需的所有資訊。換句話說,不能使用模組輸入的第一個鍵指向一個 tensordict 例項來分解一個 tensordict。一般來說,更傾向於使用 dispatch() 處理僅包含葉子節點的 tensordict。

示例

>>> class MyModuleNest(nn.Module):
...     in_keys = [("a", "c"), "d"]
...     out_keys = ["b"]
...
...     @dispatch
...     def forward(self, tensordict):
...         tensordict['b'] = tensordict['a', 'c'] + tensordict["d"]
...         return tensordict
...
>>> module = MyModuleNest()
>>> b, = module(torch.zeros(1, 2), d=torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> b, = module(torch.zeros(1, 2), torch.ones(1, 2))  # works
>>> assert (b == 1).all()
>>> try:
...     b, = module(torch.zeros(1, 2), a_c=torch.ones(1, 2))  # fails
... except:
...     print("oopsy!")
...

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源