評價此頁

torch.func.hessian#

torch.func.hessian(func, argnums=0)[原始碼]#

透過前向-反向策略,計算 func 相對於索引為 argnum 的引數(們)的 Hessian。

正向-逆向策略(組合 jacfwd(jacrev(func)))是獲得良好效能的預設選擇。也可以透過 jacfwd()jacrev() 的其他組合來計算 Hessian,例如 jacfwd(jacfwd(func))jacrev(jacrev(func))

引數
  • func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors

  • argnums (intTuple[int]) – 可選,整數或整數元組,指定要計算 Hessian 的引數。預設為 0。

返回

返回一個函式,該函式接受與 func 相同的輸入,並返回 func 相對於 argnums 指定的一個或多個引數的 Hessian。

注意

您可能會看到此 API 報錯“forward-mode AD not implemented for operator X”。如果是這種情況,請提交 bug 報告,我們將優先處理。另一種方法是使用 jacrev(jacrev(func)),它具有更好的運算子覆蓋範圍。

基本用法,對於 R^N -> R^1 函式,得到一個 N x N 的 Hessian

>>> from torch.func import hessian
>>> def f(x):
>>>   return x.sin().sum()
>>>
>>> x = torch.randn(5)
>>> hess = hessian(f)(x)  # equivalent to jacfwd(jacrev(f))(x)
>>> assert torch.allclose(hess, torch.diag(-x.sin()))