評價此頁

torch.func.jvp#

torch.func.jvp(func, primals, tangents, *, strict=False, has_aux=False)[source]#

代表 Jacobian-vector product,返回一個元組,其中包含 func(*primals) 的輸出以及在 primals 處計算的“func 的 Jacobian”與 tangents 的乘積。這也被稱為前向模式自動微分。

引數
  • func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors

  • primals (Tensors) – 傳遞給 func 的位置引數,所有這些引數都必須是 Tensor。返回的函式還將計算相對於這些引數的導數。

  • tangents (Tensors) – 用於計算 Jacobian-vector product 的“向量”。其結構和大小必須與 func 的輸入相同。

  • has_aux (bool) – 一個標誌,指示 func 返回一個 (output, aux) 元組,其中第一個元素是要求導函式的輸出,第二個元素是其他不會被求導的輔助物件。預設為 False。

返回

返回一個 (output, jvp_out) 元組,包含 funcprimals 處計算的輸出以及 Jacobian-vector product。如果 has_aux True,則返回一個 (output, jvp_out, aux) 元組。

注意

您可能會看到此 API 報錯“forward-mode AD not implemented for operator X”。如果出現這種情況,請提交一個 bug 報告,我們將優先處理。

當您希望計算函式 R^1 -> R^N 的梯度時,jvp 非常有用。

>>> from torch.func import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1.0, 2.0, 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.0),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1.0, 2, 3]))

jvp() 可以透過為每個輸入傳遞對應的 tangents 來支援具有多個輸入的函式。

>>> from torch.func import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)