torch.nn.factory_kwargs#
- torch.nn.factory_kwargs(kwargs)[原始碼]#
返回規範化的 factory kwargs 字典。
給定 kwargs,返回一個可以直接傳遞給 factory 函式(如 torch.empty)的規範化 factory kwargs 字典,如果存在不受識別的 kwargs,則會報錯。
此函式使編寫類似以下程式碼的程式碼變得簡單
class MyModule(nn.Module): def __init__(self, **kwargs): factory_kwargs = torch.nn.factory_kwargs(kwargs) self.weight = Parameter(torch.empty(10, **factory_kwargs))
為什麼你應該使用這個函式而不是直接傳遞 kwargs?
1. 此函式會進行錯誤驗證,因此如果存在意外的 kwargs,我們會立即報告錯誤,而不是將其推遲到 factory 呼叫。2. 此函式支援一個特殊的 factory_kwargs 引數,可用於顯式指定一個將被用於 factory 函式的 kwarg,以防其中一個 factory kwarg 與簽名中已存在的引數衝突(例如,在簽名
def f(dtype, **kwargs)中,你可以透過指定dtype來為 factory 函式指定dtype,這與 dtype 引數不同,即f(dtype1, factory_kwargs={"dtype": dtype2}))。