評價此頁

torch.nn.factory_kwargs#

torch.nn.factory_kwargs(kwargs)[原始碼]#

返回規範化的 factory kwargs 字典。

給定 kwargs,返回一個可以直接傳遞給 factory 函式(如 torch.empty)的規範化 factory kwargs 字典,如果存在不受識別的 kwargs,則會報錯。

此函式使編寫類似以下程式碼的程式碼變得簡單

class MyModule(nn.Module):
    def __init__(self, **kwargs):
        factory_kwargs = torch.nn.factory_kwargs(kwargs)
        self.weight = Parameter(torch.empty(10, **factory_kwargs))

為什麼你應該使用這個函式而不是直接傳遞 kwargs

1. 此函式會進行錯誤驗證,因此如果存在意外的 kwargs,我們會立即報告錯誤,而不是將其推遲到 factory 呼叫。2. 此函式支援一個特殊的 factory_kwargs 引數,可用於顯式指定一個將被用於 factory 函式的 kwarg,以防其中一個 factory kwarg 與簽名中已存在的引數衝突(例如,在簽名 def f(dtype, **kwargs) 中,你可以透過指定 dtype 來為 factory 函式指定 dtype,這與 dtype 引數不同,即 f(dtype1, factory_kwargs={"dtype": dtype2}))。