評價此頁

ParametrizationList#

class torch.nn.utils.parametrize.ParametrizationList(modules, original, unsafe=False)[原始碼]#

一個順序容器,用於儲存和管理引數化 torch.nn.Module 的原始引數或緩衝區。

module[tensor_name] 使用 register_parametrization() 進行引數化時,module.parametrizations[tensor_name] 的型別就是 ParametrizationList

如果第一個註冊的引數化具有返回一個張量的 right_inverse 或不具有 right_inverse(在這種情況下,我們假設 right_inverse 是恆等函式),它將以 original 的名稱儲存該張量。如果它有一個返回多個張量的 right_inverse,這些張量將分別註冊為 original0original1,依此類推。

警告

register_parametrization() 會在內部使用此類。此處記錄是為了完整性。使用者不應例項化此類。

引數
  • modules (sequence) – 代表引數化的模組序列

  • original (ParameterTensor) – 被引數化的引數或緩衝區

  • unsafe (bool) – 一個布林標誌,表示引數化是否可能改變張量的 dtype 和形狀。預設為 False。警告:在註冊時,引數化未進行一致性檢查。啟用此標誌需自擔風險。

right_inverse(value)[原始碼]#

按註冊的逆序呼叫引數化列表中的 right_inverse 方法。

然後,如果 right_inverse 輸出一個張量,則將其儲存在 self.original 中;如果輸出多個張量,則分別儲存在 self.original0self.original1 等中。

引數

value (Tensor) – 用於初始化模組的值