torch.func#
建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日
torch.func,之前稱為“functorch”,是 JAX 式的可組合函式變換,適用於 PyTorch。
注意
該庫目前處於 Beta 階段。這意味著這些功能通常都能正常工作(除非另有說明),並且我們(PyTorch 團隊)致力於推進該庫的發展。然而,API 可能會根據使用者反饋而更改,並且我們無法完全覆蓋所有 PyTorch 操作。
如果您對 API 或您希望覆蓋的使用場景有任何建議,請在 GitHub 上提交 issue 或與我們聯絡。我們很樂意瞭解您如何使用該庫。
什麼是可組合函式變換?#
“函式變換”是一種高階函式,它接受一個數值函式並返回一個計算不同量的新函式。
torch.func提供了自動微分變換(grad(f)返回一個計算f的梯度的函式)、向量化/批次化變換(vmap(f)返回一個在輸入批次上計算f的函式)等。這些函式變換可以任意組合。例如,組合
vmap(grad(f))可以計算一個稱為“每樣本梯度”的量,這是標準 PyTorch 目前無法高效計算的。