評價此頁

torch.func API 參考#

建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日

函式變換#

vmap

vmap 是向量化對映;vmap(func) 返回一個新的函式,該函式將 func 對映到輸入的某個維度上。

grad

grad 運算元有助於計算 func 相對於由 argnums 指定的輸入(們)的梯度。

grad_and_value

返回一個用於計算梯度和原始值(或前向計算)的元組的函式。

vjp

代表向量-雅可比矩陣乘積,返回一個元組,其中包含 func 應用於 primals 的結果,以及一個函式,該函式在給定 cotangents 時,計算 func 相對於 primals 的反向模式雅可比矩陣乘以 cotangents

jvp

代表雅可比矩陣-向量乘積,返回一個元組,其中包含 func(*primals) 的輸出以及“在 primals 處評估的 func 的雅可比矩陣”乘以 tangents

linearize

返回 funcprimals 處的值以及在 primals 處的線性近似。

jacrev

使用反向模式自動微分計算 func 相對於 argnum 索引處的引數(們)的雅可比矩陣。

jacfwd

使用前向模式自動微分計算 func 相對於 argnum 索引處的引數(們)的雅可比矩陣。

hessian

透過前向-反向策略,計算 func 相對於索引為 argnum 的引數(們)的 Hessian。

functionalize

functionalize 是一個變換,可用於從函式中移除(中間)突變和別名,同時保留函式的語義。

用於處理 torch.nn.Modules 的實用程式#

通常,您可以變換呼叫 torch.nn.Module 的函式。例如,以下是計算一個接收三個值並返回三個值的函式的雅可比矩陣的示例。

model = torch.nn.Linear(3, 3)

def f(x):
    return model(x)

x = torch.randn(3)
jacobian = jacrev(f)(x)
assert jacobian.shape == (3, 3)

但是,如果您想執行諸如計算模型引數的雅可比矩陣之類的操作,則需要一種方法來構造一個將引數作為函式輸入的函式。這就是 functional_call() 的用途:它接受一個 nn.Module、轉換後的 parameters 和 Module 前向傳播的輸入。它返回使用替換後的引數執行 Module 前向傳播的值。

以下是我們如何計算引數上的雅可比矩陣。

model = torch.nn.Linear(3, 3)

def f(params, x):
    return torch.func.functional_call(model, params, x)

x = torch.randn(3)
jacobian = jacrev(f)(dict(model.named_parameters()), x)

functional_call

透過替換提供的引數和緩衝區,在模組上執行函式式呼叫。

stack_module_state

為使用 vmap() 進行整合準備一系列 torch.nn.Modules。

replace_all_batch_norm_modules_

就地更新 root,透過將 running_meanrunning_var 設定為 None,併為 root 中的任何 nn.BatchNorm 模組將 track_running_stats 設定為 False。

如果您正在尋找有關修復 BatchNorm 模組的資訊,請遵循此處提供的指導。

除錯實用程式#

debug_unwrap

解開一個函子張量(例如。