torch.func API 參考#
建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日
函式變換#
vmap 是向量化對映; |
|
|
|
返回一個用於計算梯度和原始值(或前向計算)的元組的函式。 |
|
代表向量-雅可比矩陣乘積,返回一個元組,其中包含 |
|
代表雅可比矩陣-向量乘積,返回一個元組,其中包含 func(*primals) 的輸出以及“在 |
|
返回 |
|
使用反向模式自動微分計算 |
|
使用前向模式自動微分計算 |
|
透過前向-反向策略,計算 |
|
functionalize 是一個變換,可用於從函式中移除(中間)突變和別名,同時保留函式的語義。 |
用於處理 torch.nn.Modules 的實用程式#
通常,您可以變換呼叫 torch.nn.Module 的函式。例如,以下是計算一個接收三個值並返回三個值的函式的雅可比矩陣的示例。
model = torch.nn.Linear(3, 3)
def f(x):
return model(x)
x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)
但是,如果您想執行諸如計算模型引數的雅可比矩陣之類的操作,則需要一種方法來構造一個將引數作為函式輸入的函式。這就是 functional_call() 的用途:它接受一個 nn.Module、轉換後的 parameters 和 Module 前向傳播的輸入。它返回使用替換後的引數執行 Module 前向傳播的值。
以下是我們如何計算引數上的雅可比矩陣。
model = torch.nn.Linear(3, 3)
def f(params, x):
return torch.func.functional_call(model, params, x)
x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)
透過替換提供的引數和緩衝區,在模組上執行函式式呼叫。 |
|
為使用 |
|
就地更新 |
如果您正在尋找有關修復 BatchNorm 模組的資訊,請遵循此處提供的指導。
除錯實用程式#
解開一個函子張量(例如。 |