torch.func.linearize#
- torch.func.linearize(func, *primals)[原始碼]#
返回
func在primals處的值以及在primals處的線性近似值。- 引數
func (Callable) – 一個接受一個或多個引數的 Python 函式。
primals (Tensors) –
func的位置引數,它們必須都是 Tensor。這些是函式被線性逼近的值。
- 返回
返回一個
(output, jvp_fn)元組,其中包含func應用於primals後的輸出,以及一個用於計算在primals處求值的func的 jvp 的函式。- 返回型別
如果要在
primals處多次計算 jvp,那麼linearize會很有用。然而,為了實現這一點,linearize 會儲存中間計算,並比直接應用 jvp 有更高的記憶體要求。因此,如果所有tangents都已知,那麼計算 vmap(jvp) 而不是使用 linearize 可能會更有效。注意
linearize會計算func兩次。請提交一個 issue 以便實現單次計算。示例
>>> import torch >>> from torch.func import linearize >>> def fn(x): ... return x.sin() ... >>> output, jvp_fn = linearize(fn, torch.zeros(3, 3)) >>> jvp_fn(torch.ones(3, 3)) tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]) >>>