與 VLLM 整合:架構和使用指南¶
本教程全面概述了 TorchAO 如何與 VLLM 整合,以及需要實現哪些內容才能使一項新技術能夠端到端工作。
配置系統¶
1. HuggingFace 模型配置¶
TorchAO 量化透過模型的 config.json 檔案進行配置
{
"model_type": "llama",
"quant_type": {
"default": {
"_type": "Int4WeightOnlyConfig",
"_data": {
"group_size": 128,
"use_hqq": true
}
}
}
}
2. TorchAO 配置類¶
所有量化方法都繼承自 AOBaseConfig
from torchao.core.config import AOBaseConfig
from torchao.quantization import Int4WeightOnlyConfig
# Example configuration
config = Int4WeightOnlyConfig(
group_size=128,
use_hqq=True,
)
assert isinstance(config, AOBaseConfig)
注意
所有量化配置都繼承自 torchao.core.config.AOBaseConfig,它提供了序列化和驗證功能。
3. 模組級配置¶
為了進行精細控制,請使用 ModuleFqnToConfig
from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig
config = ModuleFqnToConfig({
"model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64),
"model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(),
"_default": Int4WeightOnlyConfig(group_size=128) # Default for other modules
})
使用示例¶
1. 使用 HuggingFace 整合量化模型¶
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization import Int4WeightOnlyConfig
# Create quantization configuration
quantization_config = TorchAoConfig(
quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)
)
# Load and automatically quantize the model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
# Save quantized model (see Serialization section below for safe_serialization details)
model.push_to_hub("your-username/Llama-3.2-1B-int4", safe_serialization=False)
另請參閱
有關量化配置的更多資訊,請參閱 torchao.quantization.Int4WeightOnlyConfig 和 torchao.quantization.Int8WeightOnlyConfig。
2. 使用 VLLM 進行服務¶
# Start VLLM server with TorchAO quantized model
vllm serve your-username/Llama-3.2-1B-int4 \
--quantization torchao \
--dtype bfloat16 \
為 VLLM 新增新的量化方法¶
VLLM 相容性的最低要求¶
要使新的 TorchAO 量化方法能夠與 VLLM 一起工作,您需要實現支援 **張量並行** 的最低張量子類操作。VLLM 使用 narrow() 和 copy_() 將資料從 state dict 中載入到裝置上,這些操作需要特定的 aten 操作。
為什麼是這些?¶
VLLM 的張量並行工作方式是:
一個有用的實現模式是 _apply_fn_to_data,它將一個給定函式應用於類中所有具有 Tensor 型別的屬性。下面是一個通用的實現,對於大多數子類都應該有效。我們在 torchao 程式碼庫中大量使用了這種模式。
def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to all tensor components stored on this class"""
tensor_names, ctx = self.__tensor_flatten__()
# Apply the function to each tensor component
new_tensors = {}
for name in tensor_names:
new_tensors[name] = fn(getattr(self, name))
return self.__class__.__tensor_unflatten__(
new_tensors,
ctx,
None, # outer_size parameter
None, # outer_stride parameter
)
新增新量化方法的步驟指南¶
1. 建立您的 Tensor 子類¶
注意
有關張量子類及其設計原則的更多詳細資訊,請參閱 什麼是張量子類? 文件。
from torchao.core.config import AOBaseConfig
from torchao.utils import TorchAOBaseTensor
@dataclass
class MyNewQuantConfig(AOBaseConfig):
"""Configuration for your new quantization method"""
bits: int = 8
VERSION: ClassVar[int] = 1
class MyQuantizedTensor(TorchAOBaseTensor):
"""Example based on FbgemmFp8Tensor - stores quantized data + scale"""
tensor_data_attrs = ["quantized_data", "scale"]
tensor_attributes = ["dtype"]
def __new__(cls, quantized_data, scale, dtype):
shape = quantized_data.shape
return torch.Tensor._make_wrapper_subclass(
cls, shape, device=quantized_data.device, dtype=dtype, requires_grad=False
)
def __init__(self, quantized_data, scale, dtype):
self.quantized_data = quantized_data
self.scale = scale
def __tensor_flatten__(self) -> Tuple[List[str], List]:
"""Serialize tensor subclass into plain tensors and metadata"""
return self.tensor_data_attrs, [
getattr(self, attr) for attr in self.tensor_attributes
]
@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, torch.Tensor],
tensor_attributes: List,
outer_size: Optional[torch.Size],
outer_stride: Optional[Tuple],
) -> "MyQuantizedTensor":
"""Reconstruct tensor subclass from serialized data"""
return cls(
*[tensor_data_dict[name] for name in cls.tensor_data_attrs],
*tensor_attributes,
)
2. 實現所需的 VLLM 操作¶
from torch.utils._python_dispatch import return_and_correct_aliasing
@MyQuantizedTensor.implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(func)
)
@MyQuantizedTensor.implements([aten._to_copy.default])
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)
@MyQuantizedTensor.implements([aten.slice.Tensor])
def _(func, types, args, kwargs):
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0 or dim == 1:
# NOTE the slicing here will likely be different for different quant techniques
return return_and_correct_aliasing(
func, args, kwargs,
args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step))
)
else:
raise NotImplementedError(f"Slicing along dim={dim} not supported")
3. 註冊到 TorchAO 的量化系統¶
from torchao.quantization.transform_module import register_quantize_module_handler
@register_quantize_module_handler(MyNewQuantConfig)
def _my_quant_transform(module: torch.nn.Module, config: MyNewQuantConfig):
"""Transform function that applies your quantization to a module"""
weight = module.weight
# Your quantization logic here
quantized_weight = my_quantization_function(weight, config)
# Replace the weight with your quantized tensor
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
return module
重要提示
使用 torchao.quantization.transform_module.register_quantize_module_handler() 裝飾器可以將您的配置類註冊到 TorchAO 的量化系統中。
關鍵實現細節¶
特定硬體的線性運算¶
您的量化張量的前向傳播決定了硬體支援,以及當呼叫 torch.nn.functional.linear() 時實際執行的操作。
@MyQuantizedTensor.implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = args[0], args[1], args[2] if len(args) > 2 else None
# This is where you define what hardware your method supports
if hasattr(weight_tensor, 'use_cutlass_kernel'):
return my_cutlass_linear(input_tensor, weight_tensor, bias)
elif hasattr(weight_tensor, 'use_triton_kernel'):
return my_triton_linear(input_tensor, weight_tensor, bias)
else:
# Fallback - dequantize and use standard linear
return torch.nn.functional.linear(
input_tensor, weight_tensor.dequantize(), bias
)
編譯優勢¶
使用 torch.compile() 後,張量子類的開銷會消失,這在 VLLM 中是預設啟用的。
Tensor 子類的權衡¶
**編譯**:對於消除子類開銷至關重要。如果沒有編譯,除非您的模型 GPU 佔用率非常高,否則 CPU 上的分派開銷可能會嚴重影響效能。
檢查點定義了模型的行為。您可能會說“難道所有檢查點都這樣做嗎?”。這是正確的,但人們通常只將 torch.Tensor 視為其資料。而實際上,它是一個真正的類,它帶來了 Dispatcher 以及 ATen 註冊到的所有 Kernel。當您定義自己的張量子類時,您正在構建一個獨立的微型世界。這個世界有一個不同的資料表示,但也需要您明確定義支援哪些操作,併為想要支援的所有硬體提供實現。起初,這可能感覺有點像“幽靈的遠距離作用”。但它可以非常強大。例如,僅透過 3 個定義就能支援 TP。
序列化和模型共享¶
SafeTensors 支援¶
**當前狀態**:由於張量子類的限制,TorchAO 量化模型還無法使用 safetensors 進行序列化。儲存量化模型時,必須使用 safe_serialization=False。
**解決方法**:對於生產使用,將模型推送到 HuggingFace Hub 時,請使用 safe_serialization=False 儲存模型。
**未來工作**:TorchAO 團隊正在積極研究對張量子類的 safetensors 支援。請在此跟蹤進度:pytorch/ao#2338
整合架構圖¶
1. 高階模型流程:Transformers → VLLM + TorchAO¶
此圖顯示了從模型建立到服務的端到端流程。
graph LR
A[HuggingFace Model] --> B[Transformers AutoModel]
B --> C{Quantization Config?}
C -->|TorchAO Config| D[Apply TorchAO Quantization]
C -->|No Config| E[Standard Model]
D --> F[Quantized Model w/ Tensor Subclasses]
E --> G[Standard PyTorch Model]
F --> H[VLLM Model Loading]
G --> H
H --> I[VLLM Distributed Engine]
I --> J[Tensor Parallel Sharding]
J --> K[Optimized Inference]
style D fill:#e1f5fe
style F fill:#f3e5f5
style J fill:#e8f5e8
2. VLLM 中的 TorchAO 整合點¶
此圖顯示了 VLLM 如何檢測和應用 TorchAO 量化。
graph LR
A[Model Config Detection] --> B{quantization=torchao?}
B -->|Yes| C[TorchAOConfig.from_config]
B -->|No| D[Other Quantization Methods]
C --> E[Parse HF quant_type]
E --> F[config_from_dict]
F --> G[AOBaseConfig Instance]
G --> H[get_quant_method per layer]
H --> I{Layer Type?}
I -->|LinearBase| J[TorchAOLinearMethod]
I -->|Other| K[UnquantizedLinearMethod]
J --> L[create_weights]
L --> M[torchao_quantize_param_data]
M --> N[Quantized Tensor Subclass]
style C fill:#e1f5fe
style G fill:#f3e5f5
style N fill:#e8f5e8
3. Kernel 分派:將外部 Kernel 引入 VLLM¶
此圖說明了張量子類如何實現 VLLM 內部的自定義 Kernel 分派。
graph LR
A[F.linear Call in VLLM] --> B[MyQuantTensor torch_function]
B --> C[Custom implements Handler]
C --> D{Hardware Check}
D --> E[Dispatch to External Kernel]
E --> F[Execute Optimized Kernel]
F --> G[Return Result to VLLM]
subgraph "External Libraries"
H[TorchAO CUTLASS]
I[TorchAO Triton]
J[FBGEMM-GPU]
K[Custom Libraries]
end
subgraph "Tensor Subclass Code"
L[implements F.linear]
M[custom_linear_impl]
N[call external kernel]
end
E --> H
E --> I
E --> J
E --> K
C --> L
L --> M
M --> N
N --> E
style B fill:#e8f6ff,color:#000
style C fill:#fff3e0,color:#000
style E fill:#e8f5e8,color:#000
style L fill:#f3e5f5,color:#000