Building a Convolution/Batch Norm fuser with torch.compile#

Authors: Horace He, Will Feng

What you will learn
  • How to register custom fusion patterns with torch.compile's pattern matcher

Prerequisites
  • PyTorch v2.7.0

Note

This optimization only applies to models in inference mode (i.e. model.eval()). However, torch.compile's pattern matching system works for both training and inference.

First, let's get a few imports out of the way (we will be using them in the code below).

from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In this tutorial, we are going to build a model consisting of convolutions and batch norms. Note that this model has some tricky components: some of the conv/batch norm patterns are hidden within a Sequential module, and one of the batch norms is wrapped inside another Module.

class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)
    def forward(self, x):
        return self.mod(x)

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x

model = M().to(device)
model.eval()

Fusing Convolution with Batch Norm#

One of the primary challenges of automatically fusing convolution and batch norm in PyTorch is that PyTorch does not provide an easy way of accessing the computational graph. torch.compile resolves this by capturing the computational graph during compilation, which lets us apply pattern-based optimizations across the entire model, including operations nested inside Sequential modules or wrapped in custom modules.

import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import register_replacement

torch.compile will capture a graph representation of the model. During compilation, modules hidden inside Sequential containers and wrapped modules are all inlined into the graph, making them available for pattern matching and optimization.
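To see this inlining for ourselves, we can pass a custom backend to torch.compile that simply prints the captured FX graph and then runs it eagerly. This is only an inspection sketch, not part of the fusion pass built below; the name inspect_backend is our own.

def inspect_backend(gm, example_inputs):
    # The backend receives the captured FX GraphModule. The operations hidden
    # inside `nested` and `WrappedBatchNorm` all appear in this single graph.
    print(gm.graph)
    return gm.forward  # fall back to eager execution of the captured graph

with torch.no_grad():
    torch.compile(model, backend=inspect_backend)(torch.randn(1, 1, 4, 4).to(device))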

Unlike some other fusions, fusing convolution with batch norm does not require any new operators. Instead, since batch norm during inference consists of a pointwise add and multiply, these operations can be "baked" into the weights of the preceding convolution. This allows us to remove the batch norm entirely from our model! Read https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The code here is copied from pytorch/pytorch for clarity.

def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
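
As a quick sanity check (our addition, not part of the original code), we can confirm that the eager helpers above behave as advertised: the fused convolution should reproduce conv followed by batch norm on a fresh pair of modules.

check_conv = nn.Conv2d(1, 1, 1).to(device).eval()
check_bn = nn.BatchNorm2d(1).to(device).eval()
check_x = torch.randn(2, 1, 4, 4, device=device)
with torch.no_grad():
    fused = fuse_conv_bn_eval(check_conv, check_bn)
    # C(x) should equal B(A(x)) in inference mode, up to floating point error
    torch.testing.assert_close(fused(check_x), check_bn(check_conv(check_x)))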

Pattern Matching with torch.compile#

Now that we have our fusion logic, we need to register a pattern that torch.compile's pattern matcher will recognize and replace during compilation.

# Define the pattern we want to match: conv2d followed by batch_norm
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
    bn_out = torch.nn.functional.batch_norm(
        conv_out, bn_mean, bn_var, bn_weight, bn_bias,
        training=False, eps=1e-5
    )
    return bn_out

def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
    )
    return torch.nn.functional.conv2d(x, fused_weight, fused_bias)

# Example inputs are needed to trace the pattern functions.
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
# These are used to trace the pattern functions to create the match template.
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
# will be matched regardless of channels, kernel size, or spatial dimensions.
# - x: input tensor (batch_size, channels, height, width)
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
# - conv_bias: (out_channels,)
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
example_inputs = [
    torch.randn(1, 1, 4, 4).to(device),  # x: input tensor
    torch.randn(1, 1, 1, 1).to(device),  # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
    torch.randn(1).to(device),           # conv_bias: 1 output channel
    torch.randn(1).to(device),           # bn_mean: batch norm running mean
    torch.randn(1).to(device),           # bn_var: batch norm running variance
    torch.randn(1).to(device),           # bn_weight: batch norm weight (gamma)
    torch.randn(1).to(device),           # bn_bias: batch norm bias (beta)
]

from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor import config

# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()

register_replacement(
    conv_bn_pattern,
    conv_bn_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,
)

# Create a custom pass function that applies our patterns
def conv_bn_fusion_pass(graph):
    return patterns.apply(graph)

# Set our custom pass in the config
config.post_grad_custom_post_pass = conv_bn_fusion_pass

Note

We make some simplifications here for demonstration purposes, such as only matching 2D convolutions. The pattern matcher in torch.compile can handle more complex patterns.
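
For example, the same recipe extends to other convolution dimensions. The sketch below is our own illustration rather than part of the tutorial: it registers an analogous pattern for 1D convolutions into the same PatternMatcherPass. fuse_conv_bn_weights already handles the 1D case, because it reshapes the batch norm scale to match the rank of the convolution weight.

def conv1d_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = torch.nn.functional.conv1d(x, conv_weight, conv_bias)
    return torch.nn.functional.batch_norm(
        conv_out, bn_mean, bn_var, bn_weight, bn_bias,
        training=False, eps=1e-5
    )

def conv1d_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
    )
    return torch.nn.functional.conv1d(x, fused_weight, fused_bias)

conv1d_example_inputs = [
    torch.randn(1, 1, 8).to(device),   # x: (batch, channels, length)
    torch.randn(1, 1, 1).to(device),   # conv_weight: (out_channels, in_channels, kernel_size)
    torch.randn(1).to(device),         # conv_bias
    torch.randn(1).to(device),         # bn_mean
    torch.randn(1).to(device),         # bn_var
    torch.randn(1).to(device),         # bn_weight
    torch.randn(1).to(device),         # bn_bias
]

register_replacement(
    conv1d_bn_pattern,
    conv1d_bn_replacement,
    conv1d_example_inputs,
    pm.fwd_only,
    patterns,  # reuse the pass we created above
)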

Testing out our Fusion Pass#

We can now run this fusion pass on our initial toy model and verify that our results are identical. In addition, we can print out the code generated for our fused model and verify that there are no more batch norms.
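
One way to inspect that generated code is to enable the output_code logging artifact before compiling, so that Inductor prints the kernels it produces; after fusion there should be no batch_norm left in them. This step is optional and independent of the checks below.

# Optional: have torch.compile log the code Inductor generates, so we can
# verify by eye that batch norm has been folded into the convolutions.
import torch._logging
torch._logging.set_logs(output_code=True)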

from torch._dynamo.utils import counters

# Clear the counters before compilation
counters.clear()

# Ensure pattern matcher is enabled
config.pattern_matcher = True

fused_model = torch.compile(model, backend="inductor")
inp = torch.randn(5, 1, 1, 1).to(device)

# Run the model to trigger compilation and pattern matching
with torch.no_grad():
    output = fused_model(inp)
    expected = model(inp)
    torch.testing.assert_close(output, expected)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"

# Create a model with different shapes than our example_inputs
test_model_diff_shape = nn.Sequential(
    nn.Conv2d(3, 16, 5),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, 7),
    nn.BatchNorm2d(32),
).to(device).eval()

counters.clear()
compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor")
test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device)
with torch.no_grad():
    compiled_diff_shape(test_input_diff_shape)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"

Benchmarking our Fusion on ResNet18#

We can test our fusion pass on a larger model like ResNet18 and see how much this pass improves inference performance.

import torchvision.models as models
import time

rn18 = models.resnet18().to(device)
rn18.eval()

inp = torch.randn(10, 3, 224, 224).to(device)
output = rn18(inp)

def benchmark(model, iters=20):
    with torch.no_grad():
        # Warm up (this also triggers compilation for compiled models)
        for _ in range(10):
            model(inp)
        if device.type == "cuda":
            torch.cuda.synchronize()  # finish warm-up kernels before timing
        begin = time.time()
        for _ in range(iters):
            model(inp)
        if device.type == "cuda":
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        return f"{time.time() - begin:.4f}s"

# Benchmark original model
print("Original model time: ", benchmark(rn18))

# Compile with our custom pattern
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")

# Benchmark compiled model
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ", benchmark(compiled_with_pattern_matching))


Conclusion#

As we can see, torch.compile provides a powerful way to implement graph transformations and optimizations through pattern matching. By registering custom patterns, we can extend torch.compile's optimization capabilities to handle domain-specific transformations.

The conv-bn fusion demonstrated here is just one example of what is possible with torch.compile's pattern matching system.