評價此頁

torch.func.linearize#

torch.func.linearize(func, *primals)[原始碼]#

返回 funcprimals 處的值以及在 primals 處的線性近似值。

引數
  • func (Callable) – 一個接受一個或多個引數的 Python 函式。

  • primals (Tensors) – func 的位置引數,它們必須都是 Tensor。這些是函式被線性逼近的值。

返回

返回一個 (output, jvp_fn) 元組,其中包含 func 應用於 primals 後的輸出,以及一個用於計算在 primals 處求值的 func 的 jvp 的函式。

返回型別

tuple[Any, Callable]

如果要在 primals 處多次計算 jvp,那麼 linearize 會很有用。然而,為了實現這一點,linearize 會儲存中間計算,並比直接應用 jvp 有更高的記憶體要求。因此,如果所有 tangents 都已知,那麼計算 vmap(jvp) 而不是使用 linearize 可能會更有效。

注意

linearize 會計算 func 兩次。請提交一個 issue 以便實現單次計算。

示例

>>> import torch
>>> from torch.func import linearize
>>> def fn(x):
...     return x.sin()
...
>>> output, jvp_fn = linearize(fn, torch.zeros(3, 3))
>>> jvp_fn(torch.ones(3, 3))
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>>