• 文件 >
  • 編寫自己的量化張量
快捷方式

編寫您自己的量化張量

torchao 中的量化建立在張量子類(tensor subclasses)的基礎上。它們是 torchao 提供靈活推理和訓練支援的主要擴充套件點,透過低精度計算,同時與 torch.compile、autograd 和分散式原語等重要的 PyTorch 特性相結合。

在本教程中,我們將重點介紹與模組替換(module swaps)相比,利用張量子類的好處,並透過一個簡單的示例來演示如何使用這種方法來表達量化。

什麼是張量子類?

張量子類只是繼承自 torch.Tensor 的類。它們允許使用者在模型中現有的操作之間插入自定義計算邏輯,這樣像 torch.add 這樣的頂級 torch 名稱空間中的函式將繼續無縫工作。

與張量子類方法顯而易見的替代方法是模組替換:例如,只需將模型中的所有 nn.Linear 模組替換為您自定義的 Int8QuantizedLinear 模組。與此方法相比,使用張量子類有幾個重要的好處:

  1. 更精細的整合點。 模組替換在模組級別攔截計算,因此對於依賴 torch 函式或原生模組變體的模型(例如,nn.Linear 的稍作修改的版本)無效。相比之下,由於張量子類在函式/操作級別攔截計算,只要使用相同的函式/操作,我們就可以量化模型。

  2. 更好的可組合性。 使用模組替換組合多個功能很麻煩。例如,組合兩個現有的 Int8QuantizedLinear 和 DistributedLinear 模組需要使用者建立一個另一個線性類,該類複製這些功能。張量子類透過簡單地將一個子類包裝在另一個子類中來繞過此問題。如果外部張量(例如 DTensor)意識到內部張量已被量化,這也可以提供效能優勢,從而可以使用更少的網路和記憶體頻寬執行昂貴的 allgather 操作。

  3. 重用 PyTorch 元件。 使用張量子類來表達量化是很自然的,因為量化張量只是具有不同 dtype 的 torch.Tensors。模型結構保持不變(nn.Linears 仍然是 nn.Linears),因此後續的最佳化傳遞也可以與之前完全相同。


在教程的其餘部分,我們將透過一個示例來演示如何使用這兩種方法來實現量化。有關張量子類的更多閱讀,請參考:

透過模組替換進行量化

我們首先透過一個簡單的示例來實現 int8 僅權重量化,方法是使用模組替換。所有程式碼都可以在這個 示例指令碼 中找到。我們將使用以下函式將 float32 張量量化為 int8 張量:

from typing import Tuple
import torch

def int8_symmetric_quantize(
    fp32_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Symmetrically quantize the torch.float32 tensor into torch.int8.
    Return a 2-tuple of (quantized value, scale).

    input: dimensions=[M, N], dtype=torch.float32
    output: dimensions=[M, N], dtype=torch.int8
    scale: dimensions=[M, 1], dtype=torch.float32
    """
    quant_min = -128
    quant_max = 127
    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scale = max_val_pos / (float(quant_max - quant_min) / 2)
    scale = scale.view(fp32_tensor.shape[0], -1)
    out = torch.round(fp32_tensor * (1.0 / scale))
    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
    return out, scale

接下來,我們將建立一個新的 QuantizedLinear 模組,它呼叫此函式來動態量化權重:

class QuantizedLinear(torch.nn.Linear):
    """
    Linear module that performs dynamic and symmetric weight-only
    int8 quantization.
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_int8, scale = int8_symmetric_quantize(self.weight)
        return torch.matmul(x, w_int8.t().to(x.dtype)) * scale.t()

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_linear = cls(mod.in_features, mod.out_features, mod.bias)
        new_linear.weight = mod.weight
        return new_linear

然後,唯一剩下的就是將模型中的所有 nn.Linear 模組替換為我們的新 QuantizedLinear。讓我們使用以下玩具模型進行演示:

import copy

class ToyModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x

float_model = ToyModel(64, 128, 32).cuda()
quantized_model = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model, name, new_linear)

驗證模型現在是否使用了我們的 QuantizedLinear 模組。該模型現在已準備就緒!

>>> print(float_model)
ToyModel(
  (linear1): Linear(in_features=64, out_features=128, bias=False)
  (linear2): Linear(in_features=128, out_features=32, bias=False)
)

>>> print(quantized_model)
ToyModel(
  (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
  (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
)

這種簡單方法的一個重要缺點是靈活性。目前這僅適用於原生 PyTorch 模組,但如果模型具有稍作修改的線性模組,例如支援分散式訓練,該怎麼辦?如果模型直接呼叫線性(torch.nn.functional.linear)的功能版本,它也將無效。

此外,假設我們想將此功能與分佈相結合,分佈也透過模組替換實現。除了建立另一個組合了這兩個功能的模組之外,沒有乾淨的方法可以做到這一點。這些限制可以透過張量子類來解決,張量子類是攔截模型中自定義計算(如量化)的一種更優雅的方式。

透過張量子類進行量化

在這裡,我們將使用一個基於 __torch_dispatch__ 的張量子類來重新實現上述量化技術。

張量子類(通常利用 __torch_dispatch__)是 PyTorch 中一個非常強大/靈活的擴充套件點。它們作為擴充套件點有兩個主要目的:

  1. 張量子類允許您覆蓋(幾乎)每個 PyTorch API 的**實現**,並且在很大程度上用於實現其他 PyTorch 產品。

  2. 張量子類允許您將張量資料與附加**元資料進行耦合**。一些示例:

    1. [分散式] 關於張量如何在各個節點之間分片(DTensor文件)的元資料

    2. [量化] 尺度/零點元資料(AffineQuantizedTensor

    3. [不規則性] 關於不規則結構(NestedTensor文件)的元資料

一些關於張量子類的其他資源,供感興趣的讀者參考:

  1. __torch_dispatch__ 文件(連結

  2. 什麼是 __torch_dispatch__(以及為什麼使用它)連結

  3. 使用 __torch_dispatch__ 實現 FlopCounter 和 MemoryTracker 的 Google Colab(連結

話不多說,讓我們開始定義我們最基本的對稱量化張量子類:

class Int8SymmetricTensor(torch.Tensor):
    """
    Our subclass represents a tensor that has been quantized to int8
    It will hold two inner tensors:
      int_data: int8[M, N]
      scale: fp32[M, 1]
    """

    @staticmethod
    @torch._dynamo.disable
    def __new__(cls, int_data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            int_data.shape,
            strides=int_data.stride(),
            storage_offset=int_data.storage_offset(),
            dtype=scale.dtype,
            device=int_data.device,
        )

    @torch._dynamo.disable
    def __init__(self, int_data: torch.Tensor, scale: torch.Tensor):
        # inner data expected to be quantized already
        assert int_data.dtype is torch.int8
        # we could do more work to support ndim > 2!
        assert int_data.ndim == 2
        assert scale.ndim == 2
        self.int_data = int_data
        self.scale = scale

    def __tensor_flatten__(self) -> Tuple[List[str], Any]:
        """
        Returns a tuple of:
          names of all inner tensor attributes (two in our case)
          any other additional, non-tensor metadata.

        Needed for PT2 support.
        """
        return ["int_data", "scale"], None

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, extra_metadata, outer_size=None, outer_stride=None):
        """
         __tensor_unflatten__ should effectively undo __tensor_flatten__.

        inputs:
          a dict mapping names of inner tensor attributes back to the tensors
          the constant metadata from __tensor_flatten__
        output:
          a new instance of your subclass

        Needed for PT2 support.
        """
        assert extra_metadata is None
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        return Int8SymmetricTensor(int_data, scale)

    def __repr__(self):
        return f'Int8SymmetricTensor(int_data={repr(self.int_data)}, scale={repr(self.scale)})'

    @staticmethod
    def from_float(float_tensor):
        """
        Actually performs the symmetric quantization.
        In our simple inference example we will quantize weights "ahead-of-time",
        although later in a training example we can quantize/dequantize
        during model execution, inside of our __torch_dispatch__

        input:
          float32 torch.Tensor
        output:
          Int8SymmetricTensor
        """
        int8_tensor, scale = int8_symmetric_quantize(float_tensor)
        return Int8SymmetricTensor(int8_tensor, scale)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """
        Called for each ATen operator that our subclass is passed as an input to.
        We need to define our own implementation for every operator here.
        """
        if kwargs is None:
            kwargs = {}
        if func not in op_implementations_dict:
            raise AssertionError(f'Int8SymmetricTensor does not yet support op: {str(func)}')
        return op_implementations_dict[func](func, *args, **kwargs)


# Convenience function for registering our own implementation
# to every ATen operator in PyTorch
op_implementations_dict = {}
def register_op(ops: List[torch._ops.OpOverload]):
    def impl_decorator(op_impl):
        global op_implementations_dict
        for op in ops:
            op_implementations_dict[op] = op_impl
        return op_impl

    return impl_decorator

在上面的程式碼中,我們做了幾件事:

  1. 定義了一個基本的“包裝器”張量子類——它實際上是一個容器物件,儲存了一些內部資料(特別是兩個張量,對應於我們的 int8 資料和尺度)

  2. 定義了一個 __torch_dispatch__ 實現,對於我們模型對任何子類輸入呼叫的每個 ATen 操作都會呼叫它

  3. (為了支援 PT2)定義了一個 __tensor_flatten__/__tensor_unflatten__ 方法。這是我們的子類與 torch.compile 相容的一些要求中最重要的部分(稍後會詳細介紹)。它有效地告訴 torch.compile 如何將我們的子類“解糖”成其內部元件。

  4. (為了支援 PT2)在構造方法(__new____init__)上添加了 torch._dynamo.disable 裝飾器(稍後會詳細介紹)。

應該實現哪些操作?

PyTorch 擁有相當大的操作表面。與其試圖讓我們的新張量子類實現 100% 的覆蓋,不如讓我們專注於玩具模型所需的那些操作。

但是,我們的模型呼叫了哪些操作,這樣我們才知道首先要實現什麼?暴力方法是反覆執行模型,檢視子類中出現的錯誤操作。更優雅的方法是記錄模型在執行過程中看到的每個操作。這可以透過另一個 LoggingTensor 子類來實現,如此示例所示。

讓我們在下面實現必要的操作:

from torch.utils._python_dispatch import return_and_correct_aliasing

@register_op([torch.ops.aten.mm.default])
def int8_mm(func, x, weight):
    assert isinstance(weight, Int8SymmetricTensor), "Int8SymmetricTensor: matmul currently only supports the weight in low precision, not the input!"
    return torch.mm(x, weight.int_data.to(x.dtype)) * weight.scale

@register_op([
    torch.ops.aten.detach.default,
    torch.ops.aten.t.default,
])
def int8_view_ops(func, *args, **kwargs):
    assert isinstance(args[0], Int8SymmetricTensor)
    out_data = func(args[0].int_data, *args[1:], **kwargs)
    out_scale = func(args[0].scale, *args[1:], **kwargs)
    out = Int8SymmetricTensor(out_data, out_scale)
    return return_and_correct_aliasing(func, args, kwargs, out)

您會很快注意到一件事:我們的模型本身由幾個線性層組成,但我們看到像 aten.taten.mm 這樣的操作擊中了我們的子類。一些背景資訊:

  • 我們在 C++ 中有許多操作分解,它們執行在張量子類“之上”。linear 就是這樣一個操作(分解位於 此處)。

  • 分解可能是好的,因為它們縮小了您作為子類作者需要實現的大 API 表面積。但如果寧願覆蓋“更高層”的操作而不是其分解中的底層操作,它們可能會很麻煩。

  • 如果您寧願在更高層覆蓋某些操作(例如 Linear),您可以使用 __torch_function__示例)來實現。值得注意的是,如果您想要自動微分支援,那麼您在 __torch_function__ 層執行的任何覆蓋都需要以可微分的方式編寫,而您在 __torch_dispatch__ 中執行的任何覆蓋都將自動可微分。

我們的實現中有一些細微之處值得指出:

  1. 您會注意到,我們在 mm 實現中不再需要轉置權重/尺度。這是因為在 aten.mm 操作發生之前,轉置“已經”完成了。

  2. 我們的 aten.mm 實現**不**返回張量子類輸出。從這個意義上說,我們的量化子類的“傳播”在矩陣乘法處結束。這對應於我們的權重是低精度的,但我們需要執行高精度的矩陣乘法的事實。總的來說,子類作者可以自由選擇他們的子類會傳播或不傳播哪些操作。如果您希望模型中的每個函式都被量化(包括所有逐點和歸約操作),您可以編寫您的子類實現來量化每個操作的輸出,並始終返回一個子類。

  3. 我們能夠為 4 個檢視操作重用相同的實現。總的來說,許多操作可能適用於相當通用的實現:解開任何子類輸入,在內部張量上執行底層操作,並將輸出重新包裝到子類中。

    • 然而,是否總是可以重用實現取決於您想做什麼。例如,我們透過對內部資料和內部尺度張量執行相同的轉置來實現我們子類的 transpose(dim0, dim1)。如果我們的尺度和資料張量具有不同數量的維度,這種情況就不起作用了,因此在這種情況下,轉置需要自定義實現。

比較輸出

在完成所有這些之後,讓我們用兩種量化版本執行我們的模型,並確認它們給出相同的輸出!

float_model = ToyModel(64, 128, 32).cuda()
quantized_model_module_swap = copy.deepcopy(float_model)
quantized_model_subclass = copy.deepcopy(float_model)

# Swap torch.nn.Linear with QuantizedLinear
for name, child in quantized_model_module_swap.named_children():
    if type(child) == torch.nn.Linear:
        new_linear = QuantizedLinear.from_float(child)
        setattr(quantized_model_module_swap, name, new_linear)

# Swap torch.nn.Linear weights with Int8SymmetricTensor subclasses
for name, child in quantized_model_subclass.named_children():
    if type(child) == torch.nn.Linear:
        subclass_param = Int8SymmetricTensor.from_float(child.weight)
        child.weight = torch.nn.Parameter(subclass_param, requires_grad=True)

with torch.no_grad():
    x = torch.randn(64, 64, 64, device='cuda')
    out_module_swap = quantized_model_module_swap(x)
    out = quantized_model_subclass(x)
    print(torch.allclose(out, out_module_swap))  # prints True

    # We can also use torch.compile to fuse some of our quantized logic
    out_compiled = torch.compile(quantized_model_subclass)(x)
    print(torch.allclose(out, out_compiled))  # prints True

下一步

在本教程中,我們演示瞭如何構建一個簡單的量化張量子類。這是本系列教程的第一部分。 下一篇文章將討論如何為您的張量子類新增更高階的功能,例如使其可訓練、與 DTensors 組合以及新增張量並行支援。有關 torchao 中 AffineQuantizedTensor 如何使用張量子類構建的更詳細示例,請參閱 此示例

如果您在實現子類時有任何疑問,請隨時在此處 提出問題

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源