評價此頁

torch.func#

建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日

torch.func,之前稱為“functorch”,是 JAX 式的可組合函式變換,適用於 PyTorch。

注意

該庫目前處於 Beta 階段。這意味著這些功能通常都能正常工作(除非另有說明),並且我們(PyTorch 團隊)致力於推進該庫的發展。然而,API 可能會根據使用者反饋而更改,並且我們無法完全覆蓋所有 PyTorch 操作。

如果您對 API 或您希望覆蓋的使用場景有任何建議,請在 GitHub 上提交 issue 或與我們聯絡。我們很樂意瞭解您如何使用該庫。

什麼是可組合函式變換?#

  • “函式變換”是一種高階函式,它接受一個數值函式並返回一個計算不同量的新函式。

  • torch.func 提供了自動微分變換(grad(f) 返回一個計算 f 的梯度的函式)、向量化/批次化變換(vmap(f) 返回一個在輸入批次上計算 f 的函式)等。

  • 這些函式變換可以任意組合。例如,組合 vmap(grad(f)) 可以計算一個稱為“每樣本梯度”的量,這是標準 PyTorch 目前無法高效計算的。

為什麼需要可組合函式變換?#

目前在 PyTorch 中有一些難以實現的使用案例

  • 計算逐樣本梯度(或其他逐樣本量)

  • 在單機上執行模型整合

  • 高效地批處理 MAML 內部迴圈中的任務

  • 高效地計算雅可比矩陣和 Hessian 矩陣

  • 高效地計算批處理雅可比矩陣和 Hessian 矩陣

組合 vmap()grad()vjp() 變換使我們能夠表達上述內容,而無需為每一種情況設計單獨的子系統。這種可組合函式變換的思想來源於 JAX 框架