注意
轉到末尾 下載完整的示例程式碼。
神經切線核#
創建於: 2023年3月15日 | 最後更新: 2025年9月19日 | 最後驗證: 未驗證
神經切線核 (NTK) 是一個描述 神經網路在訓練過程中如何演變 的核。近年來,圍繞它進行了大量的研究 。本教程受 JAX 中 NTK 實現的啟發 (請參閱 Fast Finite Width Neural Tangent Kernel 以獲取詳細資訊),演示瞭如何使用 PyTorch 的可組合函式變換 torch.func 輕鬆計算此量。
注意
本教程需要 PyTorch 2.6.0 或更高版本。
設定#
首先,進行一些設定。讓我們定義一個簡單的 CNN,我們希望計算其 NTK。
import torch
import torch.nn as nn
from torch.func import functional_call, vmap, vjp, jvp, jacrev
if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:
device = torch.accelerator.current_accelerator()
else:
device = torch.device("cpu")
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, (3, 3))
self.conv2 = nn.Conv2d(32, 32, (3, 3))
self.conv3 = nn.Conv2d(32, 32, (3, 3))
self.fc = nn.Linear(21632, 10)
def forward(self, x):
x = self.conv1(x)
x = x.relu()
x = self.conv2(x)
x = x.relu()
x = self.conv3(x)
x = x.flatten(1)
x = self.fc(x)
return x
讓我們生成一些隨機資料
x_train = torch.randn(20, 3, 32, 32, device=device)
x_test = torch.randn(5, 3, 32, 32, device=device)
建立模型的函式版本#
torch.func 變換作用於函式。特別是,為了計算 NTK,我們需要一個接受模型引數和單個輸入 (而不是輸入批次!) 並返回單個輸出的函式。
我們將使用 torch.func.functional_call,它允許我們使用不同的引數/緩衝區呼叫 nn.Module,以幫助完成第一步。
請記住,模型最初是為接受輸入資料點批次而編寫的。在我們的 CNN 示例中,沒有批次間操作。也就是說,批次中的每個資料點都獨立於其他資料點。基於此假設,我們可以輕鬆生成一個在單個數據點上評估模型的函式
net = CNN().to(device)
# Detaching the parameters because we won't be calling Tensor.backward().
params = {k: v.detach() for k, v in net.named_parameters()}
def fnet_single(params, x):
return functional_call(net, params, (x.unsqueeze(0),)).squeeze(0)
計算 NTK:方法 1 (雅可比收縮)#
我們已準備好計算經驗 NTK。兩個資料點 \(x_1\) 和 \(x_2\) 的經驗 NTK 定義為在 \(x_1\) 處評估的模型雅可比與在 \(x_2\) 處評估的模型雅可比之間的矩陣乘積
在 \(x_1\) 是資料點批次且 \(x_2\) 是資料點批次的批次情況下,我們想要 \(x_1\) 和 \(x_2\) 中所有資料點組合的雅可比之間的矩陣乘積。
第一種方法包括執行此操作——計算兩個雅可比,然後收縮它們。以下是在批次情況下計算 NTK 的方法
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)
print(result.shape)
torch.Size([20, 5, 10, 10])
在某些情況下,您可能只想要該數量的對角線或跡,尤其是在您事先知道網路架構會產生 NTK,其中非對角線元素可以近似為零的情況下。可以輕鬆地調整上述函式來執行此操作
def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):
# Compute J(x1)
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac1 = jac1.values()
jac1 = [j.flatten(2) for j in jac1]
# Compute J(x2)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
jac2 = jac2.values()
jac2 = [j.flatten(2) for j in jac2]
# Compute J(x1) @ J(x2).T
einsum_expr = None
if compute == 'full':
einsum_expr = 'Naf,Mbf->NMab'
elif compute == 'trace':
einsum_expr = 'Naf,Maf->NM'
elif compute == 'diagonal':
einsum_expr = 'Naf,Maf->NMa'
else:
assert False
result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
result = result.sum(0)
return result
result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')
print(result.shape)
torch.Size([20, 5])
此方法的漸近時間複雜度為 \(N O [FP]\) (計算雅可比的時間) + \(N^2 O^2 P\) (收縮雅可比的時間),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批次大小,\(O\) 是模型的輸出大小,\(P\) 是引數總數,\([FP]\) 是透過模型進行單次前向傳播的成本。有關詳細資訊,請參閱 Fast Finite Width Neural Tangent Kernel 的第 3.2 節。
計算 NTK:方法 2 (NTK-向量積)#
接下來我們將討論一種使用 NTK-向量積計算 NTK 的方法。
此方法將 NTK 重構為一系列 NTK-向量積,這些積應用於大小為 \(O\times O\) 的單位矩陣 \(I_O\) 的列 (其中 \(O\) 是模型的輸出大小)
其中 \(e_o\in \mathbb{R}^O\) 是單位矩陣 \(I_O\) 的列向量。
令 \(\textrm{vjp}_o = J_{net}^T(x_2) e_o\)。我們可以使用向量-雅可比積來計算它。
現在,考慮 \(J_{net}(x_1) \textrm{vjp}_o\)。這是一個雅可比-向量積!
最後,我們可以使用
vmap來並行執行上述計算,遍歷 \(I_O\) 的所有列 \(e_o\)。
這表明我們可以結合使用反向模式 AD (計算向量-雅可比積) 和前向模式 AD (計算雅可比-向量積) 來計算 NTK。
讓我們來實現它
def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
def get_ntk(x1, x2):
def func_x1(params):
return func(params, x1)
def func_x2(params):
return func(params, x2)
output, vjp_fn = vjp(func_x1, params)
def get_ntk_slice(vec):
# This computes ``vec @ J(x2).T``
# `vec` is some unit vector (a single slice of the Identity matrix)
vjps = vjp_fn(vec)
# This computes ``J(X1) @ vjps``
_, jvps = jvp(func_x2, (params,), vjps)
return jvps
# Here's our identity matrix
basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
return vmap(get_ntk_slice)(basis)
# ``get_ntk(x1, x2)`` computes the NTK for a single data point x1, x2
# Since the x1, x2 inputs to ``empirical_ntk_ntk_vps`` are batched,
# we actually wish to compute the NTK between every pair of data points
# between {x1} and {x2}. That's what the ``vmaps`` here do.
result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK', result)
# Disable TensorFloat-32 for convolutions on Ampere+ GPUs to sacrifice performance in favor of accuracy
with torch.backends.cudnn.flags(allow_tf32=False):
result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)
result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)
assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)
/usr/local/lib/python3.10/dist-packages/torch/backends/cudnn/__init__.py:145: UserWarning:
Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.com.tw/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
我們為 empirical_ntk_ntk_vps 編寫的程式碼看起來直接翻譯自上面的數學!這展示了函式變換的強大之處:如果您只使用 torch.autograd.grad,很難編寫一個高效的版本。
此方法的漸近時間複雜度為 \(N^2 O [FP]\),其中 \(N\) 是 \(x_1\) 和 \(x_2\) 的批次大小,\(O\) 是模型的輸出大小,\([FP]\) 是透過模型進行單次前向傳播的成本。因此,此方法比方法 1 (雅可比收縮) 執行更多的網路前向傳播 ( \(N^2 O\) 而不是 \(N O\) ),但完全避免了收縮成本 (沒有 \(N^2 O^2 P\) 項,其中 \(P\) 是模型引數總數)。因此,當 \(O P\) 相對於 \([FP]\) 很大時,此方法更可取,例如具有許多輸出 \(O\) 的全連線 (非卷積) 模型。記憶體方面,兩種方法應該相當。有關詳細資訊,請參閱 Fast Finite Width Neural Tangent Kernel 的第 3.3 節。
指令碼總執行時間: (0 分鐘 0.825 秒)