torch.func.vjp#
- torch.func.vjp(func, *primals, has_aux=False)[原始碼]#
代表向量-雅可比矩陣乘積,返回一個元組,其中包含
func應用於primals的結果,以及一個函式,該函式在給定cotangents時,計算func相對於primals的反向模式雅可比矩陣乘以cotangents。- 引數
func (Callable) – 一個接受一個或多個引數的 Python 函式。必須返回一個或多個 Tensor。
primals (Tensors) – 傳遞給
func的位置引數,所有這些引數都必須是 Tensor。返回的函式還將計算相對於這些引數的導數。has_aux (bool) – 標誌,表示
func返回一個(output, aux)元組,其中第一個元素是要進行微分的函式的輸出,第二個元素是不會進行微分的其他輔助物件。預設為 False。
- 返回
返回一個
(output, vjp_fn)元組,包含func應用於primals後的輸出,以及一個用於計算func相對於所有primals的 vjp 的函式。該函式使用傳遞給返回函式的餘切。如果has_aux 為 True,則返回一個(output, vjp_fn, aux)元組。返回的vjp_fn函式將返回每個 VJP 的元組。
在簡單情況下使用時,
vjp()的行為與grad()相同。>>> x = torch.randn([5]) >>> f = lambda x: x.sin().sum() >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> grad = vjpfunc(torch.tensor(1.0))[0] >>> assert torch.allclose(grad, torch.func.grad(f)(x))
然而,
vjp()可以透過為每個輸出傳遞餘切來支援具有多個輸出的函式。>>> x = torch.randn([5]) >>> f = lambda x: (x.sin(), x.cos()) >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc((torch.ones([5]), torch.ones([5]))) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp()甚至可以支援輸出為 Python 結構。>>> x = torch.randn([5]) >>> f = lambda x: {"first": x.sin(), "second": x.cos()} >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])} >>> vjps = vjpfunc(cotangents) >>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
由
vjp()返回的函式將計算相對於每個primals的偏導數。>>> x, y = torch.randn([5, 4]), torch.randn([4, 5]) >>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y) >>> cotangents = torch.randn([5, 5]) >>> vjps = vjpfunc(cotangents) >>> assert len(vjps) == 2 >>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1))) >>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals是f的位置引數。所有關鍵字引數都使用其預設值。>>> x = torch.randn([5]) >>> def f(x, scale=4.): >>> return x * scale >>> >>> (_, vjpfunc) = torch.func.vjp(f, x) >>> vjps = vjpfunc(torch.ones_like(x)) >>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
注意
將 PyTorch
torch.no_grad與vjp一起使用。情況 1:在函式內部使用torch.no_grad。>>> def f(x): >>> with torch.no_grad(): >>> c = x ** 2 >>> return x - c
在這種情況下,
vjp(f)(x)將尊重內部的torch.no_grad。情況 2:在
torch.no_grad上下文管理器內使用vjp。>>> with torch.no_grad(): >>> vjp(f)(x)
在這種情況下,
vjp將尊重內部的torch.no_grad,但不會尊重外部的。這是因為vjp是一個“函式變換”:其結果不應依賴於f外部的上下文管理器的結果。