注意
跳轉到末尾 下載完整的示例程式碼。
Jacobians, Hessians, hvp, vhp, and more: composing function transforms#
Created On: Mar 15, 2023 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024
計算雅可比矩陣或海森矩陣在許多非傳統深度學習模型中都很有用。使用 PyTorch 的常規自動微分 API(Tensor.backward()、torch.autograd.grad)高效計算這些量非常困難(或令人煩惱)。PyTorch 受 JAX 啟發的 函式變換 API 提供了高效計算各種高階自動微分量的方法。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
計算雅可比矩陣#
import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)
讓我們從一個我們想計算其雅可比矩陣的函式開始。這是一個具有非線性啟用的簡單線性函式。
讓我們新增一些模擬資料:權重、偏置和特徵向量 x。
D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D) # feature vector
我們可以將 predict 視為一個將輸入 x 從 \(R^D \to R^D\) 對映的函式。PyTorch Autograd 計算向量-雅可比矩陣乘積。為了計算這個 \(R^D \to R^D\) 函式的完整雅可比矩陣,我們將不得不透過每次使用不同的單位向量來逐行計算它。
def compute_jac(xp):
jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
for vec in unit_vectors]
return torch.stack(jacobian_rows)
xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)
jacobian = compute_jac(xp)
print(jacobian.shape)
print(jacobian[0]) # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295, 0.4490, 0.3661, -0.1672, -1.1190,
0.1705, -0.6683, 0.1851, 0.1630, 0.0634, 0.6547, 0.5908, -0.1308])
與其逐行計算雅可比矩陣,不如使用 PyTorch 的 torch.vmap 函式變換來消除 for 迴圈並向量化計算。我們不能直接將 vmap 應用於 torch.autograd.grad;相反,PyTorch 提供了一個 torch.func.vjp 變換,它可以與 torch.vmap 組合。
from torch.func import vmap, vjp
_, vjp_fn = vjp(partial(predict, weight, bias), x)
ft_jacobian, = vmap(vjp_fn)(unit_vectors)
# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)
在後面的教程中,反向模式 AD 和 vmap 的組合將為我們提供每樣本梯度。在本教程中,反向模式 AD 和 vmap 的組合將使我們能夠計算雅可比矩陣!vmap 和自動微分變換的各種組合可以為我們提供不同的有趣量。
PyTorch 提供 torch.func.jacrev 作為一個方便的函式,它執行 vmap-vjp 組合來計算雅可比矩陣。jacrev 接受一個 argnums 引數,指定我們想要計算其雅可比矩陣的引數。
from torch.func import jacrev
ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)
# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)
讓我們比較一下計算雅可比矩陣的兩種方法的效能。函式變換版本快得多(並且隨著輸出數量的增加,速度會更快)。
總的來說,我們期望透過 vmap 進行向量化可以幫助消除開銷並更好地利用硬體。
vmap 透過將外迴圈推入函式的原始操作來實現這種魔力,以獲得更好的效能。
讓我們建立一個快速函式來評估效能並處理微秒和毫秒的測量。
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
faster = second.times[0]
slower = first.times[0]
gain = (slower-faster)/slower
if gain < 0: gain *=-1
final_gain = gain*100
print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")
然後執行效能比較。
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)
print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7fe0030e4520>
compute_jac(xp)
1.39 ms
1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fe002d2b1f0>
jacrev(predict, argnums=2)(weight, bias, x)
397.74 us
1 measurement, 500 runs , 1 thread
讓我們使用我們的 get_perf 函式進行相對效能比較。
get_perf(no_vmap_timer, "without vmap", with_vmap_timer, "vmap")
Performance delta: 71.2962 percent improvement with vmap
此外,反轉問題並說我們想計算模型引數(權重、偏置)的雅可比矩陣而不是輸入的雅可比矩陣也相當容易。
# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
反向模式雅可比矩陣(jacrev) vs 前向模式雅可比矩陣(jacfwd)#
我們提供兩個 API 來計算雅可比矩陣:jacrev 和 jacfwd。
jacrev使用反向模式 AD。正如您上面所見,它是我們的vjp和vmap變換的組合。jacfwd使用前向模式 AD。它被實現為我們jvp和vmap變換的組合。
jacfwd 和 jacrev 可以互相替換,但它們具有不同的效能特徵。
通常的經驗法則是,如果您正在計算 \(R^N \to R^M\) 函式的雅可比矩陣,並且輸出數量遠多於輸入數量(例如,\(M > N\)),則首選 jacfwd,否則使用 jacrev。這個規則也有例外,但一個非嚴格的解釋如下:
在反向模式 AD 中,我們是逐行計算雅可比矩陣,而在前向模式 AD(它計算雅可比矩陣-向量乘積)中,我們是逐列計算。雅可比矩陣有 M 行 N 列,所以如果它在一側更長或更寬,我們可能更傾向於處理更少行或列的方法。
首先,讓我們對輸入多於輸出的情況進行基準測試。
Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fdffe303eb0>
jacfwd(predict, argnums=2)(weight, bias, x)
740.13 us
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fdffe74b310>
jacrev(predict, argnums=2)(weight, bias, x)
8.50 ms
1 measurement, 500 runs , 1 thread
然後進行相對基準測試。
get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev", );
Performance delta: 1047.7832 percent improvement with jacrev
現在反過來——輸出(M)多於輸入(N)。
Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())
jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)
print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fe002b33f40>
jacfwd(predict, argnums=2)(weight, bias, x)
6.78 ms
1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7fdffe708490>
jacrev(predict, argnums=2)(weight, bias, x)
483.86 us
1 measurement, 500 runs , 1 thread
以及相對效能比較。
get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1300.3639 percent improvement with jacfwd
使用 functorch.hessian 進行海森矩陣計算#
我們提供了一個方便的 API 來計算海森矩陣:torch.func.hessiani。海森矩陣是雅可比矩陣的雅可比矩陣(或偏導數的偏導數,即二階導數)。
這表明我們可以簡單地組合 functorch 雅可比矩陣變換來計算海森矩陣。事實上,在底層,hessian(f) 僅僅是 jacfwd(jacrev(f))。
注意:為了提高效能:根據您的模型,您可能還想使用 jacfwd(jacfwd(f)) 或 jacrev(jacrev(f)) 來計算海森矩陣,並利用上述關於寬矩陣與高矩陣的經驗法則。
from torch.func import hessian
# lets reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)
hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)
讓我們驗證一下,無論使用 hessian API 還是 jacfwd(jacfwd()),我們都能獲得相同的結果。
True
批次雅可比矩陣和批次海森矩陣#
在上面的例子中,我們處理的是單個特徵向量。在某些情況下,您可能希望計算輸出批次的雅可比矩陣相對於輸入批次。也就是說,給定形狀為 (B, N) 的輸入批次和一個從 \(R^N \to R^M\) 的函式,我們希望得到形狀為 (B, M, N) 的雅可比矩陣。
最簡單的方法是使用 vmap。
batch_size = 64
Din = 31
Dout = 33
weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")
bias = torch.randn(Dout)
x = torch.randn(batch_size, Din)
compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])
如果您有一個從 (B, N) -> (B, M) 開始的函式,並且確信每個輸入都會產生一個獨立的輸出,那麼有時也可以在不使用 vmap 的情況下完成此操作,透過對輸出求和,然後計算該函式的雅可比矩陣。
def predict_with_output_summed(weight, bias, x):
return predict(weight, bias, x).sum(0)
batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)
如果您有一個從 \(R^N \to R^M\) 到輸入批次的函式,則將 vmap 與 jacrev 組合以計算批次雅可比矩陣。
最後,批次海森矩陣的計算方法也類似。最簡單的方法是使用 vmap 來批處理海森矩陣計算,但在某些情況下,求和技巧也有效。
compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))
batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])
計算海森矩陣-向量乘積#
計算海森矩陣-向量乘積 (hvp) 的樸素方法是具體化完整的海森矩陣並與向量執行點積。我們可以做得更好:事實證明,我們不必具體化完整的海森矩陣來執行此操作。我們將探討計算海森矩陣-向量乘積的兩種(多種)不同策略:- 反向模式 AD 與反向模式 AD 的組合 - 反向模式 AD 與前向模式 AD 的組合
反向模式 AD 與前向模式 AD 的組合(而不是反向模式與反向模式)通常是計算 hvp 的更節省記憶體的方法,因為前向模式 AD 不需要構建 Autograd 圖並儲存中間結果以進行反向傳播。
以下是一些示例用法。
def f(x):
return x.sin().sum()
x = torch.randn(2048)
tangent = torch.randn(2048)
result = hvp(f, (x,), (tangent,))
如果 PyTorch 的前向 AD 對您的操作沒有覆蓋,那麼我們可以改為組合反向模式 AD 與反向模式 AD。
指令碼總執行時間: (0 分鐘 10.446 秒)