torch.func.hessian#
- torch.func.hessian(func, argnums=0)[原始碼]#
透過前向-反向策略,計算
func相對於索引為argnum的引數(們)的 Hessian。正向-逆向策略(組合
jacfwd(jacrev(func)))是獲得良好效能的預設選擇。也可以透過jacfwd()和jacrev()的其他組合來計算 Hessian,例如jacfwd(jacfwd(func))或jacrev(jacrev(func))。- 引數
- 返回
返回一個函式,該函式接受與
func相同的輸入,並返回func相對於argnums指定的一個或多個引數的 Hessian。
注意
您可能會看到此 API 報錯“forward-mode AD not implemented for operator X”。如果是這種情況,請提交 bug 報告,我們將優先處理。另一種方法是使用
jacrev(jacrev(func)),它具有更好的運算子覆蓋範圍。基本用法,對於 R^N -> R^1 函式,得到一個 N x N 的 Hessian
>>> from torch.func import hessian >>> def f(x): >>> return x.sin().sum() >>> >>> x = torch.randn(5) >>> hess = hessian(f)(x) # equivalent to jacfwd(jacrev(f))(x) >>> assert torch.allclose(hess, torch.diag(-x.sin()))