Fusing Convolution and Batch Norm using Custom Function#
Created On: Jul 22, 2021 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024
Fusing adjacent convolution and batch norm layers together is typically an inference-time optimization to improve run-time. It is usually achieved by eliminating the batch norm layer entirely and updating the weight and bias of the preceding convolution [0]. However, this technique is not applicable for training models.
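For context, the inference-time fusion referenced above folds the batch norm statistics directly into the convolution's parameters. Below is a minimal sketch of that per-output-channel folding; the function name and arguments are illustrative only (with affine parameters gamma/beta, whereas the rest of this tutorial uses affine=False):
import torch

def fold_bn_into_conv(conv_w, conv_b, bn_mean, bn_var, bn_gamma, bn_beta, eps=1e-5):
    # BN(conv(x)) = gamma * (W x + b - mean) / sqrt(var + eps) + beta, so we can
    # rescale the conv weight/bias per output channel and fold in the shift.
    scale = bn_gamma / torch.sqrt(bn_var + eps)
    fused_w = conv_w * scale[:, None, None, None]
    fused_b = (conv_b - bn_mean) * scale + bn_beta
    return fused_w, fused_b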
In this tutorial, we will show a different technique to fuse the two layers that can be applied during training. Rather than improved runtime, the objective of this optimization is to reduce memory usage.
The idea behind this optimization is that both convolution and batch norm (as well as many other ops) need to save a copy of their input during the forward pass for use in the backward pass. For large batch sizes, these saved inputs are responsible for most of your memory usage, so being able to avoid allocating another input tensor for every convolution-batch-norm pair can be a significant reduction.
In this tutorial, we avoid this extra allocation by combining convolution and batch norm into a single layer (as a custom function). In the forward of this combined layer, we perform normal convolution and batch norm as-is, with the only difference being that we save only the input to the convolution. To obtain the input of batch norm, which is necessary to backward through it, we recompute the convolution forward again during the backward pass.
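The same save-less-and-recompute pattern, shown in miniature on a hypothetical toy function (not part of this tutorial): compute relu(2 * x), save only x, and rebuild the intermediate in backward:
import torch

class ScaleThenReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)   # save only the original input
        return (2 * x).relu()

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        y = 2 * x                  # recompute the intermediate instead of saving it
        return grad_out * (y > 0) * 2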
It is important to note that the usage of this optimization is situational. Though (by avoiding one saved buffer) we always reduce the memory allocated at the end of the forward pass, there are cases when the peak memory allocated may not actually be reduced. See the final section for more details.
For simplicity, in this tutorial we hardcode bias=False, stride=1, padding=0, dilation=1, and groups=1 for Conv2D. For BatchNorm2D, we hardcode eps=1e-3, momentum=0.1, affine=False, and track_running_stats=False. Another small difference is that we add epsilon in the denominator outside of the square root in the computation of batch norm.
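The eps placement difference mentioned above is small but measurable. The quick comparison below is purely illustrative and is not used anywhere else in the tutorial:
import torch

X = torch.randn(8, 3, 5, 5)
mean = X.mean(dim=(0, 2, 3), keepdim=True)
var = X.var(dim=(0, 2, 3), unbiased=True, keepdim=True)
eps = 1e-3
standard = (X - mean) / torch.sqrt(var + eps)    # eps inside the square root
tutorial = (X - mean) / (torch.sqrt(var) + eps)  # eps outside, as in this tutorial
print((standard - tutorial).abs().max())         # small but nonzero difference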
[0] https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
Backward Formula Implementation for Convolution#
Implementing a custom function requires us to implement the backward ourselves. In this case, we need both the backward formulas for Conv2D and BatchNorm2D. Eventually we'll chain them together in our unified backward function, but below we first implement them as their own custom functions so we can validate their correctness individually.
import torch
from torch.autograd.function import once_differentiable
import torch.nn.functional as F
def convolution_backward(grad_out, X, weight):
    # Gradient w.r.t. the weight: correlate X with grad_out over the batch and
    # spatial dims (valid for the hardcoded stride=1, padding=0 case).
    # Note: despite its name, ``grad_input`` here is the gradient of the *weight*.
    grad_input = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient w.r.t. the input: transposed convolution of grad_out with the weight
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_input
class Conv2D(torch.autograd.Function):
@staticmethod
def forward(ctx, X, weight):
ctx.save_for_backward(X, weight)
return F.conv2d(X, weight)
# Use @once_differentiable by default unless we intend to double backward
@staticmethod
@once_differentiable
def backward(ctx, grad_out):
X, weight = ctx.saved_tensors
return convolution_backward(grad_out, X, weight)
When testing with gradcheck, it is important to use double precision.
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(10, 3, 7, 7, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Conv2D.apply, (X, weight))
True
Backward Formula Implementation for Batch Norm#
Batch norm has two modes: training and eval mode. In training mode the sample statistics are a function of the inputs. In eval mode, we use the saved running statistics, which are not a function of the inputs. This makes the non-training-mode backward significantly simpler. Below we implement and test only the training-mode case.
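For contrast, in eval mode the statistics are constants, so the backward with respect to the input is just an elementwise rescale. A minimal sketch under this tutorial's conventions (affine=False, eps added outside the square root); the helper name is hypothetical and it is not used below:
def batch_norm_backward_eval(grad_out, running_var, eps=1e-3):
    # d(out)/d(X) = 1 / (sqrt(running_var) + eps), broadcast over N, H, W
    return grad_out / (torch.sqrt(running_var) + eps)[None, :, None, None]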
def unsqueeze_all(t):
# Helper function to ``unsqueeze`` all the dimensions that we reduce over
return t[None, :, None, None]
def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
# We use the formula: ``out = (X - mean(X)) / (sqrt(var(X)) + eps)``
# in batch norm 2D forward. To simplify our derivation, we follow the
# chain rule and compute the gradients as follows before accumulating
# them all into a final grad_input.
# 1) ``grad of out wrt var(X)`` * ``grad of var(X) wrt X``
# 2) ``grad of out wrt mean(X)`` * ``grad of mean(X) wrt X``
# 3) ``grad of out wrt X in the numerator`` * ``grad of X wrt X``
# We then rewrite the formulas to use as few extra buffers as possible
tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
tmp *= -1
d_denom = tmp / (sqrt_var + eps)**2 # ``d_denom = -num / denom**2``
# It is useful to delete tensors when you no longer need them with ``del``
# For example, we could've done ``del tmp`` here because we won't use it later
# In this case, it's not a big difference because ``tmp`` only has size of (C,)
# The important thing is avoid allocating NCHW-sized tensors unnecessarily
d_var = d_denom / (2 * sqrt_var) # ``denom = torch.sqrt(var) + eps``
# Compute ``d_mean_dx`` before allocating the final NCHW-sized grad_input buffer
d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
# ``d_mean_dx`` has already been reassigned to a C-sized buffer so no need to worry
# ``(1) unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)``
grad_input = X * unsqueeze_all(d_var * N)
grad_input += unsqueeze_all(-d_var * sum)
grad_input *= 2 / ((N - 1) * N)
# (2) mean (see above)
grad_input += d_mean_dx
# (3) Add 'grad_out / <factor>' without allocating an extra buffer
grad_input *= unsqueeze_all(sqrt_var + eps)
grad_input += grad_out
grad_input /= unsqueeze_all(sqrt_var + eps) # ``sqrt_var + eps > 0!``
return grad_input
class BatchNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, X, eps=1e-3):
# Don't save ``keepdim`` values for backward
sum = X.sum(dim=(0, 2, 3))
var = X.var(unbiased=True, dim=(0, 2, 3))
N = X.numel() / X.size(1)
sqrt_var = torch.sqrt(var)
ctx.save_for_backward(X)
ctx.eps = eps
ctx.sum = sum
ctx.N = N
ctx.sqrt_var = sqrt_var
mean = sum / N
denom = sqrt_var + eps
out = X - unsqueeze_all(mean)
out /= unsqueeze_all(denom)
return out
@staticmethod
@once_differentiable
def backward(ctx, grad_out):
X, = ctx.saved_tensors
return batch_norm_backward(grad_out, X, ctx.sum, ctx.sqrt_var, ctx.N, ctx.eps)
Testing with gradcheck
a = torch.rand(1, 2, 3, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(BatchNorm.apply, (a,), fast_mode=False)
True
Fusing Convolution and Batch Norm#
Now that the bulk of the work has been done, we can combine them together. Note that in (1) we only save a single buffer for backward, but this also means we recompute the convolution forward in (5). Also note that in (2), (3), (4), and (6), the code is exactly the same as in the examples above.
class FusedConvBN2DFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, X, conv_weight, eps=1e-3):
assert X.ndim == 4 # N, C, H, W
# (1) Only need to save this single buffer for backward!
ctx.save_for_backward(X, conv_weight)
# (2) Exact same Conv2D forward from example above
X = F.conv2d(X, conv_weight)
# (3) Exact same BatchNorm2D forward from example above
sum = X.sum(dim=(0, 2, 3))
var = X.var(unbiased=True, dim=(0, 2, 3))
N = X.numel() / X.size(1)
sqrt_var = torch.sqrt(var)
ctx.eps = eps
ctx.sum = sum
ctx.N = N
ctx.sqrt_var = sqrt_var
mean = sum / N
denom = sqrt_var + eps
# Try to do as many things in-place as possible
# Instead of `out = (X - a) / b`, doing `out = X - a; out /= b`
# avoids allocating one extra NCHW-sized buffer here
out = X - unsqueeze_all(mean)
out /= unsqueeze_all(denom)
return out
@staticmethod
def backward(ctx, grad_out):
X, conv_weight, = ctx.saved_tensors
# (4) Batch norm backward
# (5) We need to recompute conv
X_conv_out = F.conv2d(X, conv_weight)
grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum, ctx.sqrt_var,
ctx.N, ctx.eps)
# (6) Conv2d backward
grad_X, grad_input = convolution_backward(grad_out, X, conv_weight)
        return grad_X, grad_input, None  # gradients for X, conv_weight, and eps (not a tensor)
The next step is to wrap our functional variant in a stateful nn.Module.
import torch.nn as nn
import math
class FusedConvBN(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, exp_avg_factor=0.1,
eps=1e-3, device=None, dtype=None):
super(FusedConvBN, self).__init__()
factory_kwargs = {'device': device, 'dtype': dtype}
# Conv parameters
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
# Batch norm parameters
num_features = out_channels
self.num_features = num_features
self.eps = eps
# Initialize
self.reset_parameters()
def forward(self, X):
return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)
def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))
Use gradcheck to validate the correctness of our backward formula.
weight = torch.rand(5, 3, 3, 3, requires_grad=True, dtype=torch.double)
X = torch.rand(2, 3, 4, 4, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(FusedConvBN2DFunction.apply, (X, weight))
True
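As a quick smoke test (shapes here are arbitrary and only for illustration), the module can be used like any other layer:
layer = FusedConvBN(3, 8, kernel_size=3)
inp = torch.randn(4, 3, 16, 16)
out = layer(inp)           # the forward pass saves only the convolution input
out.sum().backward()       # the backward pass recomputes the conv output internally
print(out.shape, layer.conv_weight.grad.shape)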
Testing out our new Layer#
Use FusedConvBN to train a basic network. The code below is after some light modifications to the example here: pytorch/examples
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
# Record memory allocated at the end of the forward pass
memory_allocated = [[],[]]
class Net(nn.Module):
def __init__(self, fused=True):
super(Net, self).__init__()
self.fused = fused
if fused:
self.convbn1 = FusedConvBN(1, 32, 3)
self.convbn2 = FusedConvBN(32, 64, 3)
else:
self.conv1 = nn.Conv2d(1, 32, 3, 1, bias=False)
self.bn1 = nn.BatchNorm2d(32, affine=False, track_running_stats=False)
self.conv2 = nn.Conv2d(32, 64, 3, 1, bias=False)
self.bn2 = nn.BatchNorm2d(64, affine=False, track_running_stats=False)
self.fc1 = nn.Linear(9216, 128)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
if self.fused:
x = self.convbn1(x)
else:
x = self.conv1(x)
x = self.bn1(x)
F.relu_(x)
if self.fused:
x = self.convbn2(x)
else:
x = self.conv2(x)
x = self.bn2(x)
F.relu_(x)
x = F.max_pool2d(x, 2)
F.relu_(x)
x = x.flatten(1)
x = self.fc1(x)
x = self.dropout(x)
F.relu_(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
        if self.fused:
memory_allocated[0].append(torch.cuda.memory_allocated())
else:
memory_allocated[1].append(torch.cuda.memory_allocated())
return output
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 2 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
# Use inference mode instead of no_grad, for free improved test-time performance
with torch.inference_mode():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_kwargs = {'batch_size': 2048}
test_kwargs = {'batch_size': 2048}
if use_cuda:
cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
train_kwargs.update(cuda_kwargs)
test_kwargs.update(cuda_kwargs)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
dataset1 = datasets.MNIST('../data', train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
100%|██████████| 9.91M/9.91M [00:00<00:00, 138MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 29.2MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 88.2MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 23.9MB/s]
A Comparison of Memory Usage#
If CUDA is enabled, print out the memory usage for both fused=True and fused=False. For example, running on an NVIDIA GeForce RTX 3070 with the NVIDIA CUDA® Deep Neural Network library (cuDNN) 8.0.5: fused peak memory: 1.56GB, unfused peak memory: 2.68GB.
It is important to note that the peak memory usage for this model may vary depending on the specific cuDNN convolution algorithm used. For shallower models, the peak memory allocated by the fused model can even exceed that of the unfused model! This is because the memory allocated to compute certain cuDNN convolution algorithms can be high enough to "hide" the typical peak you would expect near the start of the backward pass.
For this reason, we also record and display the memory allocated at the end of the forward pass as an approximation, and to demonstrate that we indeed allocate one fewer buffer per fused conv-bn pair.
from statistics import mean
torch.backends.cudnn.enabled = True
if use_cuda:
peak_memory_allocated = []
for fused in (True, False):
torch.manual_seed(123456)
model = Net(fused=fused).to(device)
optimizer = optim.Adadelta(model.parameters(), lr=1.0)
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(1):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
scheduler.step()
peak_memory_allocated.append(torch.cuda.max_memory_allocated())
torch.cuda.reset_peak_memory_stats()
print("cuDNN version:", torch.backends.cudnn.version())
print()
print("Peak memory allocated:")
print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
print("Memory allocated at end of forward pass:")
print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")
Train Epoch: 0 [0/60000 (0%)] Loss: 2.348850
Train Epoch: 0 [4096/60000 (7%)] Loss: 7.906152
Train Epoch: 0 [8192/60000 (13%)] Loss: 3.856514
Train Epoch: 0 [12288/60000 (20%)] Loss: 2.177311
Train Epoch: 0 [16384/60000 (27%)] Loss: 1.874792
Train Epoch: 0 [20480/60000 (33%)] Loss: 1.719687
Train Epoch: 0 [24576/60000 (40%)] Loss: 1.590575
Train Epoch: 0 [28672/60000 (47%)] Loss: 1.528651
Train Epoch: 0 [32768/60000 (53%)] Loss: 1.350549
Train Epoch: 0 [36864/60000 (60%)] Loss: 1.202090
Train Epoch: 0 [40960/60000 (67%)] Loss: 1.015420
Train Epoch: 0 [45056/60000 (73%)] Loss: 1.012655
Train Epoch: 0 [49152/60000 (80%)] Loss: 0.902882
Train Epoch: 0 [53248/60000 (87%)] Loss: 0.814042
Train Epoch: 0 [57344/60000 (93%)] Loss: 0.728736
Test set: Average loss: 0.3424, Accuracy: 9068/10000 (91%)
Train Epoch: 0 [0/60000 (0%)] Loss: 2.349131
Train Epoch: 0 [4096/60000 (7%)] Loss: 7.946014
Train Epoch: 0 [8192/60000 (13%)] Loss: 3.232605
Train Epoch: 0 [12288/60000 (20%)] Loss: 2.597908
Train Epoch: 0 [16384/60000 (27%)] Loss: 1.940867
Train Epoch: 0 [20480/60000 (33%)] Loss: 2.448318
Train Epoch: 0 [24576/60000 (40%)] Loss: 2.038635
Train Epoch: 0 [28672/60000 (47%)] Loss: 1.660587
Train Epoch: 0 [32768/60000 (53%)] Loss: 1.328020
Train Epoch: 0 [36864/60000 (60%)] Loss: 1.166545
Train Epoch: 0 [40960/60000 (67%)] Loss: 1.226512
Train Epoch: 0 [45056/60000 (73%)] Loss: 1.466263
Train Epoch: 0 [49152/60000 (80%)] Loss: 1.013382
Train Epoch: 0 [53248/60000 (87%)] Loss: 0.836892
Train Epoch: 0 [57344/60000 (93%)] Loss: 0.817005
Test set: Average loss: 0.5698, Accuracy: 8279/10000 (83%)
cuDNN version: 91002
Peak memory allocated:
fused: 1.94GB, unfused: 1.50GB
Memory allocated at end of forward pass:
fused: 0.59GB, unfused: 0.96GB
Total running time of the script: (0 minutes 22.557 seconds)