快捷方式

靜態量化

靜態量化是指在推理或生成過程中,為所有輸入使用固定的量化範圍。與動態量化不同,動態量化為每個新的輸入批次動態計算新的量化範圍,靜態量化通常能帶來更高效的計算,但可能以犧牲量化精度為代價,因為我們無法即時適應輸入分佈的變化。

在靜態量化中,這個固定的量化範圍通常在量化模型之前,在類似輸入上進行校準。在校準階段,我們首先將觀察器(observers)插入模型中,以“觀察”要量化的輸入的分佈,然後利用此分佈來決定最終量化模型時使用的尺度(scales)和零點(zero points)。

在本教程中,我們將透過一個示例來演示如何在 torchao 中實現這一點。所有程式碼都可以在這個 示例指令碼 中找到。讓我們從一個簡單的線性模型開始。

import copy
import torch

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, k, bias=False)
        self.linear2 = torch.nn.Linear(k, n, bias=False)

    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
        return (
            torch.randn(
                batch_size, self.linear1.in_features, dtype=dtype, device=device
            ),
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")
m = torch.compile(m, mode="max-autotune")

校準階段

torchao 提供了一個簡單的觀察器實現,名為 AffineQuantizedMinMaxObserver,它在校準階段記錄流經觀察器的最小值和最大值。我們歡迎使用者實現自己期望的、更高階的觀察技術,例如依賴於移動平均值或直方圖的技術,這些技術將來可能會被新增到 torchao 中。

from torchao.quantization.granularity import PerAxis, PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType

# per tensor input activation asymmetric quantization
act_obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
)

# per channel weight asymmetric quantization
weight_obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerAxis(axis=0),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.float32,
)

接下來,我們定義一個“觀察後的線性”(ObservedLinear)模組,我們將用它來替換我們的 torch.nn.Linear。這是一個高精度(例如 fp32)的線性模組,其中插入了上述觀察器,用於在校準期間記錄輸入啟用和權重的值。

import torch.nn.functional as F

class ObservedLinear(torch.nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        bias: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.act_obs = act_obs
        self.weight_obs = weight_obs

    def forward(self, input: torch.Tensor):
        observed_input = self.act_obs(input)
        observed_weight = self.weight_obs(self.weight)
        return F.linear(observed_input, observed_weight, self.bias)

    @classmethod
    def from_float(cls, float_linear, act_obs, weight_obs):
        observed_linear = cls(
            float_linear.in_features,
            float_linear.out_features,
            act_obs,
            weight_obs,
            False,
            device=float_linear.weight.device,
            dtype=float_linear.weight.dtype,
        )
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
        return observed_linear

要將這些觀察器實際插入到我們的簡單模型中。

from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)

def insert_observers_(model, act_obs, weight_obs):
    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

    def replacement_fn(m):
        copied_act_obs = copy.deepcopy(act_obs)
        copied_weight_obs = copy.deepcopy(weight_obs)
        return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

insert_observers_(m, act_obs, weight_obs)

現在我們已準備好校準模型,這將用校準期間記錄的統計資訊填充我們插入的觀察器。我們可以簡單地將一些示例輸入饋送到我們的“觀察後的”模型來完成此操作。

for _ in range(10):
    example_inputs = m.example_inputs(dtype=dtype, device="cuda")
    m(*example_inputs)

量化階段

有多種方法可以實際量化模型。在這裡,我們介紹一種更簡單的替代方法,即定義一個 QuantizedLinear 類,我們將用它來替換我們的 ObservedLinear。定義這個新類並非嚴格必要。對於一種僅使用現有 torch.nn.Linear 的替代方法,請參閱完整的 示例指令碼

from torchao.dtypes import to_affine_quantized_intx_static

class QuantizedLinear(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        act_obs: torch.nn.Module,
        weight_obs: torch.nn.Module,
        weight: torch.Tensor,
        bias: torch.Tensor,
        target_dtype: torch.dtype,
    ):
        super().__init__()
        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
        assert weight.dim() == 2
        block_size = (1, weight.shape[1])
        self.target_dtype = target_dtype
        self.bias = bias
        self.qweight = to_affine_quantized_intx_static(
            weight, weight_scale, weight_zero_point, block_size, self.target_dtype
        )

    def forward(self, input: torch.Tensor):
        block_size = input.shape
        qinput = to_affine_quantized_intx_static(
            input,
            self.act_scale,
            self.act_zero_point,
            block_size,
            self.target_dtype,
        )
        return F.linear(qinput, self.qweight, self.bias)

    @classmethod
    def from_observed(cls, observed_linear, target_dtype):
        quantized_linear = cls(
            observed_linear.in_features,
            observed_linear.out_features,
            observed_linear.act_obs,
            observed_linear.weight_obs,
            observed_linear.weight,
            observed_linear.bias,
            target_dtype,
        )
        return quantized_linear

這個線性類在開始時計算輸入啟用和權重的尺度和零點,從而為將來的前向呼叫固定量化範圍。現在,要使用這個線性類實際量化模型,我們可以定義以下配置並將其傳遞給 torchao 的主 quantize_ API。

from dataclasses import dataclass

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)

@dataclass
class StaticQuantConfig(AOBaseConfig):
    target_dtype: torch.dtype

@register_quantize_module_handler(StaticQuantConfig)
def _apply_static_quant(
    module: torch.nn.Module,
    config: StaticQuantConfig,
):
    """
    Define a transformation associated with `StaticQuantConfig`.
    This is called by `quantize_`, not by the user directly.
    """
    return QuantizedLinear.from_observed(module, config.target_dtype)

# filter function to identify which modules to swap
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

# perform static quantization
quantize_(m, StaticQuantConfig(torch.uint8), is_observed_linear)

現在,我們將看到模型中的線性層已被替換為我們的 QuantizedLinear 類,具有固定的輸入啟用尺度和固定的量化權重。

>>> m
OptimizedModule(
  (_orig_mod): ToyLinearModel(
    (linear1): QuantizedLinear()
    (linear2): QuantizedLinear()
  )
)
>>> m.linear1.act_scale
tensor([0.0237], device='cuda:0')
>>> m.linear1.qweight
AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[142,  31,  42,  ..., 113, 157,  57],
        [ 59, 160,  70,  ...,  23, 150,  67],
        [ 44,  49, 241,  ..., 238,  69, 235],
        ...,
        [228, 255, 201,  ..., 114, 236,  73],
        [ 50,  88,  83,  ..., 109, 209,  92],
        [184, 141,  35,  ..., 224, 110,  66]], device='cuda:0',
       dtype=torch.uint8)... , scale=tensor([0.0009, 0.0010, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010,
        0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0009,
        0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010,
        0.0009, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009,
        0.0010, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
        0.0010], device='cuda:0')... , zero_point=tensor([130., 128., 122., 130., 132., 128., 125., 130., 126., 128., 129., 126.,
        128., 128., 128., 128., 129., 127., 130., 125., 128., 133., 126., 126.,
        128., 124., 127., 128., 128., 128., 129., 124., 126., 133., 129., 127.,
        126., 124., 130., 126., 127., 129., 124., 125., 127., 130., 128., 132.,
        128., 129., 128., 129., 131., 132., 127., 135., 126., 130., 124., 136.,
        131., 124., 130., 129.], device='cuda:0')... , _layout=PlainLayout()), block_size=(1, 64), shape=torch.Size([64, 64]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False)

在本教程中,我們透過一個基本示例演示瞭如何在 torchao 中執行整數靜態量化。我們還有一個示例演示瞭如何執行相同的 float8 靜態量化。有關更多詳細資訊,請參閱完整的 示例指令碼

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源