評價此頁

torch.func.vjp#

torch.func.vjp(func, *primals, has_aux=False)[原始碼]#

代表向量-雅可比矩陣乘積,返回一個元組,其中包含 func 應用於 primals 的結果,以及一個函式,該函式在給定 cotangents 時,計算 func 相對於 primals 的反向模式雅可比矩陣乘以 cotangents

引數
  • func (Callable) – 一個接受一個或多個引數的 Python 函式。必須返回一個或多個 Tensor。

  • primals (Tensors) – 傳遞給 func 的位置引數,所有這些引數都必須是 Tensor。返回的函式還將計算相對於這些引數的導數。

  • has_aux (bool) – 標誌,表示 func 返回一個 (output, aux) 元組,其中第一個元素是要進行微分的函式的輸出,第二個元素是不會進行微分的其他輔助物件。預設為 False。

返回

返回一個 (output, vjp_fn) 元組,包含 func 應用於 primals 後的輸出,以及一個用於計算 func 相對於所有 primals 的 vjp 的函式。該函式使用傳遞給返回函式的餘切。如果 has_aux True,則返回一個 (output, vjp_fn, aux) 元組。返回的 vjp_fn 函式將返回每個 VJP 的元組。

在簡單情況下使用時,vjp() 的行為與 grad() 相同。

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

然而,vjp() 可以透過為每個輸出傳遞餘切來支援具有多個輸出的函式。

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 甚至可以支援輸出為 Python 結構。

>>> x = torch.randn([5])
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() 返回的函式將計算相對於每個 primals 的偏導數。

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primalsf 的位置引數。所有關鍵字引數都使用其預設值。

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>   return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))

注意

將 PyTorch torch.no_gradvjp 一起使用。情況 1:在函式內部使用 torch.no_grad

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

在這種情況下,vjp(f)(x) 將尊重內部的 torch.no_grad

情況 2:在 torch.no_grad 上下文管理器內使用 vjp

>>> with torch.no_grad():
>>>     vjp(f)(x)

在這種情況下,vjp 將尊重內部的 torch.no_grad,但不會尊重外部的。這是因為 vjp 是一個“函式變換”:其結果不應依賴於 f 外部的上下文管理器的結果。