注意
轉到底部 下載完整的示例程式碼。
Per-sample-gradients(逐樣本梯度)#
創建於: 2023年3月15日 | 最後更新: 2025年7月30日 | 最後驗證: 2024年11月5日
這是什麼?#
逐樣本梯度計算是指計算資料批次中每個樣本的梯度。這在差分隱私、元學習和最佳化研究中是一個有用的量。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple CNN and loss function:
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
def loss_fn(predictions, targets):
return F.nll_loss(predictions, targets)
讓我們生成一個虛擬資料批次,並假裝我們在處理 MNIST 資料集。虛擬影像大小為 28x28,我們使用的迷你批次大小為 64。
device = 'cuda'
num_models = 10
batch_size = 64
data = torch.randn(batch_size, 1, 28, 28, device=device)
targets = torch.randint(10, (64,), device=device)
在常規模型訓練中,我們會將迷你批次透過模型進行前向傳播,然後呼叫 `.backward()` 來計算梯度。這將生成整個迷你批次的“平均”梯度。
model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model
loss = loss_fn(predictions, targets)
loss.backward() # back propagate the 'average' gradient of this mini-batch
與上述方法相比,逐樣本梯度計算等同於
對資料中的每個單獨樣本,執行一次前向和後向傳播,以獲得一個單獨的(逐樣本)梯度。
def compute_grad(sample, target):
sample = sample.unsqueeze(0) # prepend batch dimension for processing
target = target.unsqueeze(0)
prediction = model(sample)
loss = loss_fn(prediction, target)
return torch.autograd.grad(loss, list(model.parameters()))
def compute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
return sample_grads
per_sample_grads = compute_sample_grads(data, targets)
sample_grads[0] 是 `model.conv1.weight` 的逐樣本梯度。`model.conv1.weight.shape` 是 `[32, 1, 3, 3]`;請注意,對於批次中的每個樣本,都有一個梯度,總共 64 個。
print(per_sample_grads[0].shape)
torch.Size([64, 32, 1, 3, 3])
逐樣本梯度——高效的方法,使用函式變換#
我們可以透過使用函式變換來高效地計算逐樣本梯度。
`torch.func` 函式變換 API 對函式進行變換。我們的策略是定義一個計算損失的函式,然後應用變換來構建一個計算逐樣本梯度的函式。
我們將使用 `torch.func.functional_call` 函式來將 `nn.Module` 視為一個函式。
首先,讓我們將 `model` 的狀態提取到兩個字典中:引數和緩衝區。我們將對它們進行分離(detach),因為我們不會使用常規的 PyTorch 自動微分(例如 Tensor.backward(), torch.autograd.grad)。
from torch.func import functional_call, vmap, grad
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}
接下來,讓我們定義一個函式,該函式給定單個輸入而不是輸入批次來計算模型的損失。重要的是,此函式必須接受引數、輸入和目標,因為我們將對它們進行變換。
注意 - 由於模型最初是為處理批次而編寫的,我們將使用 `torch.unsqueeze` 新增一個批次維度。
def compute_loss(params, buffers, sample, target):
batch = sample.unsqueeze(0)
targets = target.unsqueeze(0)
predictions = functional_call(model, (params, buffers), (batch,))
loss = loss_fn(predictions, targets)
return loss
現在,讓我們使用 `grad` 變換來建立一個新的函式,該函式計算相對於 `compute_loss` 第一個引數(即 `params`)的梯度。
ft_compute_grad = grad(compute_loss)
`ft_compute_grad` 函式計算單個(樣本,目標)對的梯度。我們可以使用 `vmap` 來使其計算整個樣本和目標的批次的梯度。注意 `in_dims=(None, None, 0, 0)`,因為我們希望在資料和目標的第 0 維上對映 `ft_compute_grad`,並且對每個對映使用相同的 `params` 和緩衝區。
ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))
最後,讓我們使用我們變換後的函式來計算逐樣本梯度。
我們可以快速檢查使用 `grad` 和 `vmap` 的結果是否與單獨手動處理每個結果匹配。
for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1.2e-1, rtol=1e-5)
快速說明:對可以被 `vmap` 變換的函式型別存在一些限制。最適合變換的函式是純函式:一個輸出僅由輸入決定且沒有副作用(例如突變)的函式。`vmap` 無法處理任意 Python 資料結構的突變,但可以處理許多就地(in-place)的 PyTorch 操作。
效能比較#
想了解 `vmap` 的效能如何?
目前在較新的 GPU(如 A100 (Ampere))上獲得了最佳結果,在該示例上我們看到了高達 25 倍的速度提升,但這裡是我們構建機器上的一些結果。
def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
second_res = second.times[0]
first_res = first.times[0]
gain = (first_res-second_res)/first_res
if gain < 0: gain *=-1
final_gain = gain*100
print(f"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor} ")
from torch.utils.benchmark import Timer
without_vmap = Timer(stmt="compute_sample_grads(data, targets)", globals=globals())
with_vmap = Timer(stmt="ft_compute_sample_grad(params, buffers, data, targets)",globals=globals())
no_vmap_timing = without_vmap.timeit(100)
with_vmap_timing = with_vmap.timeit(100)
print(f'Per-sample-grads without vmap {no_vmap_timing}')
print(f'Per-sample-grads with vmap {with_vmap_timing}')
get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")
Per-sample-grads without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f1e6ed0f040>
compute_sample_grads(data, targets)
65.07 ms
1 measurement, 100 runs , 1 thread
Per-sample-grads with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f1e6f5eabf0>
ft_compute_sample_grad(params, buffers, data, targets)
3.40 ms
1 measurement, 100 runs , 1 thread
Performance delta: 1815.1264 percent improvement with vmap
還有其他最佳化的解決方案(例如 pytorch/opacus 中的)用於在 PyTorch 中計算逐樣本梯度,這些方案也比樸素方法效能更好。但令人高興的是,組合使用 `vmap` 和 `grad` 可以帶來不錯的速度提升。
總的來說,使用 `vmap` 進行向量化應該比在 for 迴圈中執行函式更快,並且與手動批處理具有競爭力。但也有一些例外,例如如果我們沒有為特定操作實現 `vmap` 規則,或者底層核心沒有針對舊硬體 (GPU) 進行最佳化。如果您遇到任何這些情況,請在 GitHub 上提出 issue 告訴我們。
指令碼總執行時間: (0 分鐘 7.904 秒)