快捷方式

量化概述

首先,我們想展示 torchao 的堆疊

Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
---------------------------------------------------------------------------------------------
    Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
---------------------------------------------------------------------------------------------
  Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
---------------------------------------------------------------------------------------------
            Basic dtypes: uint1-uint7, int1-int8, float3-float8

任何量化演算法都將使用上述堆疊中的某些元件,例如,每行動態 float8 啟用和 float8 權重量化(預設首選項)使用

基本資料型別

dtype 是一個被過度使用的術語,我們所說的基本資料型別是指不需要任何額外元資料就有意義的資料型別(例如,當人們呼叫 torch.empty(.., dtype) 時有意義)。更多細節請參閱 此帖

無論我們進行何種量化,最終都將使用一些低精度資料型別來表示量化資料或量化引數。與 torchao 相關的低精度資料型別是:

  • PyTorch 2.3 及更高版本中可用的 torch.uint1torch.uint7

  • PyTorch 2.6 及更高版本中可用的 torch.int1torch.int7

  • torch.float4_e2m1fn_x2torch.float8_e4m3fntorch.float8_e4m3fnuztorch.float8_e5m2torch.float8_e5m2fnuztorch.float8_e8m0fnu

在實際實現方面,uint1uint7int1int7 只是佔位符,沒有實際實現(即,對於這些資料型別的 PyTorch Tensor,運算元不起作用)。添加了這些資料型別的示例 PR 可以在 這裡 找到。浮點資料型別是我們稱之為“Shell Dtypes”的資料型別,它們具有有限的運算元支援。

更多詳情請參閱 官方 PyTorch 資料型別文件

注意

諸如 mxfp8、mxfp4、nvfp4 之類的派生資料型別使用這些基本資料型別實現,例如,mxfp4 使用 torch.float8_e8m0fnu 作為 scale,並使用 torch.float4_e2m1fn_x2 作為 4 位資料。

量化原始運算元

量化原始運算元是指用於在低精度量化張量和高精度張量之間進行轉換的運算元。我們主要有以下量化原始運算元:

  • choose_qparams 運算元:根據原始張量選擇量化引數,通常用於動態量化,例如,仿射量化的 scale 和 zero_point

  • quantize 運算元:根據量化引數,將原始高精度張量量化為前一節中提到的資料型別的低精度張量

  • dequantize 運算元:根據量化引數,將低精度張量反量化為高精度張量

為了適應特定用例,上述運算元可能會有所變化,例如,對於靜態量化,我們可能有一個 choose_qparams_affine_with_min_max,它會根據觀察過程中得出的 min/max 值來選擇量化引數。

對於不同的核心庫,我們可以在 torchao 中使用運算元的多個版本,例如,將 bfloat16 張量量化為原始 float8 張量並獲取 scale:_choose_scale_float8_quantize_affine_float8 用於 torchao 實現,以及來自 fbgemm 庫的 torch.ops.triton.quantize_fp8_row

高效核心

我們還將提供與低精度張量一起工作的高效核心,例如:

  • torch.ops.fbgemm.f8f8bf16_rowwise (fbgemm 庫中的行式 float8 啟用和 float8 權重矩陣乘法核心)

  • torch._scaled_mm (PyTorch 中用於行式和張量式計算的 float8 啟用和 float8 權重矩陣乘法核心)

  • int_matmul:接受兩個 int8 張量並輸出一個 int32 張量

  • int_scaled_matmul:執行矩陣乘法並對結果應用 scale。

注意

我們還可以依賴 torch.compile 生成核心(透過 triton),例如,當前的 int8 僅權重量化 核心 僅依靠 torch.compile 來加速。在這種情況下,沒有與量化型別相對應的自定義手寫“高效核心”。

量化張量(派生資料型別和打包格式)

在基本資料型別、量化原始運算元和高效核心的基礎上,我們可以將它們組合起來構建一個量化(低精度)張量,透過繼承 torch.Tensor 來實現。這個張量可以由一個高精度張量和一些引數來構造,這些引數可以配置使用者想要的特定量化。我們也可以稱之為派生資料型別,因為它可以由基本資料型別的張量和一些額外的元資料(如 scale)來表示。

量化張量的另一個維度是打包格式,即量化的原始資料在記憶體中的佈局方式。例如,對於 int4,我們可以將兩個元素打包到一個 uint8 值中,或者人們可以進行一些預混/交換操作,以使格式對於記憶體操作(從記憶體載入到暫存器)和計算更有效。

所以,總的來說,我們透過派生資料型別和打包格式來構造張量子類。

TorchAO 中的張量子類

張量

派生資料型別

打包格式

支援

Float8Tensor

縮放的 float8

普通(無需打包)

float8 啟用 + float8 權重動態量化和 float8 僅權重量化

Int4Tensor

縮放的 int4

普通(將 2 個相鄰的 int4 打包到一個 int8 值中)

int4 僅權重量化

Int4PreshuffledTensor

縮放的 int4

預混(用於最佳化載入的特殊格式)

float8 啟用 + int4 權重動態量化和 int4 僅權重量化

注意

我們沒有粒度特定的張量子類,即沒有 Float8RowwiseTensor 或 Float8BlockwiseTensor,所有粒度都在同一個張量中實現。我們通常使用一個通用的 block_size 屬性來區分不同的粒度,並且每個張量只允許支援所有可能粒度選項的一個子集。

注意

我們也不在名稱中使用動態啟用,因為我們討論的是權重張量物件,在張量子類名稱中包含啟用資訊會造成混淆。但是,我們在同一個線性函式實現中同時實現了僅權重和動態啟用量化,而無需依賴額外的抽象。這使得相關的量化操作(啟用和權重的量化)保持在同一個張量子類中。

在如何量化張量方面,大多數張量使用仿射量化,這意味著低精度張量透過仿射對映從高精度張量量化,即:low_precision_val = high_precision_val / scale + zero_point,其中 scalezero_point 是可以透過量化原始運算元或透過某些最佳化過程計算出的量化引數。另一種常見的量化型別,尤其對於較低的位元寬度(例如低於 4 位)是基於碼本/查詢表的量化,其中原始量化資料是我們可以用來查詢儲存每個索引對應值的 codebook 的索引。一種獲取碼本和用於碼本量化的原始量化資料的方法是 K-means 聚類。

量化演算法/流程

在堆疊的頂部是最終的量化演算法和量化流程。傳統上,我們有僅權重量化、動態量化和靜態量化,但現在我們也看到了更多型別的量化出現。

出於演示目的,假設在前面的步驟之後,我們定義了 Float8TensorFloat8Tensor.from_hp 接受一個高精度浮點張量和一個 target_dtype(例如 torch.float8_e4m3fn)並將其轉換為 Float8Tensor

注意:以下內容均用於解釋概念,有關我們提供的工具和示例的更詳細介紹,請參閱 貢獻者指南

僅權重量化

這是最簡單的量化形式,並且易於將僅權重量化應用於模型,特別是由於我們擁有量化張量。我們所需要做的就是:

linear_module.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear_module.weight, ...), requires_grad=False))

將以上方法應用於模型中的所有線性模組,我們將獲得一個僅權重量化模型。

動態啟用和權重量化

以前稱為“動態量化”,但它意味著我們在執行時動態地量化啟用,並且也量化權重。與僅權重量化相比,主要問題是如何將量化應用於啟用。在 torchao 中,我們傳遞啟用的量化關鍵字引數,當需要時(例如線上性層中),這些關鍵字引數將被應用於啟用。

activation_dtype = torch.float8_e4m3fn
activation_granularity = PerRow()
# define kwargs for float8 activation quantization
act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
  activation_dtype,
  activation_granularity,
)
weight_dtype = torch.float8_e4m3fn
weight_granularity = PerRow()
quantized_weight = Float8Tensor.from_hp(linear_module.weight, float8_dtype=weight_dtype, granularity=weight_granularity, act_quant_kwargs=act_quant_kwargs)
linear_module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False))

靜態啟用量化和權重量化

我們暫時跳過說明,因為我們還沒有看到許多使用基於張量子類的靜態量化流程的用例。我們建議檢視 PT2 匯出量化流程 以進行靜態量化。

其他量化流程

對於不屬於以上任何一種的量化流程/演算法,我們也打算提供常見模式的示例。例如,GPTQ 類量化流程,它被 Autoround 採用,它使用 MultiTensor 和模組鉤子來最佳化模組。

如果您正在開發一種新的量化演算法/流程,並且不確定如何以 PyTorch 原生方式實現它,請隨時提交一個 issue 來描述您的演算法是如何工作的,我們可以幫助您提供實現細節方面的建議。

訓練

上述流程主要側重於推理,但低位元資料型別張量也可用於訓練。

float8 訓練的使用者文件可以在 這裡 找到,微調文件可以在 這裡 找到。

量化感知訓練

TorchAO 也透過 quantize_ API 支援 量化感知訓練

低位元最佳化器

我們支援 低位元最佳化器,它們實現了特定型別的 4 位、8 位和 float8 量化,並且可以與 FSDP 組合(使用查詢表量化)。

量化訓練

我們在 main/torchao/prototype/quantized_training 中有量化訓練原型,並且也可以擴充套件現有的張量子類以支援訓練。初步啟用正在進行中,但仍需要大量後續工作,包括使其適用於不同的核心等。

您還可以檢視關於 量化訓練 的教程,該教程介紹瞭如何使 dtype 張量子類可訓練。

案例研究:torchao 中的 float8 動態啟用和 float8 權重量化是如何工作的?

為了將所有內容連線起來,以下是 torchao 中 float8 動態啟用和 float8 權重量化的更詳細的演練(預設核心首選項,在 H100 上,如果安裝了 fbgemm_gpu_genai 庫)

量化流程:quantize_(model, Float8DynamicActivationFloat8WeightConfig())
  • 發生的情況:linear.weight = torch.nn.Parameter(Float8Tensor.from_hp(linear.weight), requires_grad=False)

  • 量化原始運算元:torch.ops.triton.quantize_fp8_row

  • 量化張量將是 Float8Tensor,一個具有縮放 float8 派生資料型別的量化張量。

模型執行期間:model(input)
  • torch.ops.fbgemm.f8f8bf16_rowwise 在輸入、原始 float8 權重和 scale 上被呼叫。

量化期間

首先,我們從 API 呼叫開始:quantize_(model, Float8DynamicActivationFloat8WeightConfig())。它的作用是將模型中 nn.Linear 模組的權重轉換為 Float8Tensor,採用普通打包格式,無需打包,因為我們有 torch.float8_e4m3fn,它可以直接表示量化的 float8 原始資料而無需額外操作。

  • quantize_:量化權重的模型級 API,透過應用使用者(第二個引數)的配置來實現。

  • Float8DynamicActivationFloat8WeightConfig:float8 動態啟用和 float8 權重量化的配置 * 呼叫量化原始運算元 torch.ops.triton.quantize_fp8_row 將 bfloat16 張量量化為 float8 原始張量並獲取 scale。

模型執行期間

當我們執行量化模型 model(inputs) 時,我們將透過 nn.Linear 的函式式線性運算元。

return F.linear(input, weight, bias)

其中輸入是 bfloat16 張量,權重是 Float8Tensor。它會呼叫 Float8Tensor 子類的 __torch_function__,當輸入之一是 Float8Tensor 時,最終會進入 F.linear 的實現。

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
      args[0],
      args[1],
      args[2] if len(args) > 2 else None,
    )
    # quantizing activation, if `act_quant_kwargs` is specified
    if act_quant_kwargs is not None:
      input_tensor = _choose_quant_func_and_quantize_tensor(
          input_tensor, act_quant_kwargs
      )

    # omitting kernel_preference related code
    # granularity checks, let's say we are doing rowwise quant
    # both input_tensor and weight_tensor will now be Float8Tensor
    xq = input_tensor.qdata.reshape(-1, input_tensor.qdata.shape[-1])
    wq = weight_tensor.qdata.contiguous()
    x_scale = input_tensor.scale
    w_scale = weight_tensor.scale
    res = torch.ops.fbgemm.f8f8bf16_rowwise(
       xq,
       wq,
       x_scale,
       w_scale,
    ).reshape(out_shape)
    return res

該函式首先將輸入量化為 Float8Tensor,然後從輸入張量和權重張量中獲取原始 float 張量和 scale:t.qdatat.scale,並呼叫 fbgemm 核心進行 float8 動態量化的矩陣乘法:torch.ops.fbgemm.f8f8bf16_rowwise

儲存/載入期間

由於 Float8Tensor 權重仍然是 torch.Tensor,因此儲存/載入與原始高精度浮點模型的工作方式相同。更多詳情請參閱 序列化文件

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源