基準測試 API 指南¶
本教程將指導您使用 TorchAO 基準測試框架。本教程包含將新 API 整合到框架和儀表板中。
將 API 新增到基準測試實用技巧¶
該框架目前支援量化和稀疏性實用技巧,可以使用 `quantize_()` 或 `sparsity_()` 函式執行。
要新增新的實用技巧,請在 `benchmarks/microbenchmarks/utils.py` 檔案中的 `string_to_config()` 函式中新增相應的字串配置。
def string_to_config(
quantization: Optional[str], sparsity: Optional[str], **kwargs
) -> AOBaseConfig:
# ... existing code ...
elif quantization == "my_new_quantization":
# If additional information needs to be passed as kwargs, process it here
return MyNewQuantizationConfig(**kwargs)
elif sparsity == "my_new_sparsity":
return MyNewSparsityConfig(**kwargs)
# ... rest of existing code ...
現在,我們可以在整個基準測試框架中使用此實用技巧。
注意: 如果 `AOBaseConfig` 使用輸入引數,例如位寬、分組大小等,您可以在輸入中附加到字串配置中傳遞它們。例如,對於 `GemliteUIntXWeightOnlyConfig`,我們可以將位寬和分組大小作為 `gemlitewo-<bit_width>-<group_size>` 傳遞。
將模型新增到基準測試實用技巧¶
要將新的模型架構新增到基準測試系統,您需要修改 `torchao/testing/model_architectures.py`。
要新增新的模型型別,請在 `
torchao/testing/model_architectures.py` 中定義您的模型類。
class MyCustomModel(torch.nn.Module):
def __init__(self, input_dim, output_dim, dtype=torch.bfloat16):
super().__init__()
# Define your model architecture
self.layer1 = torch.nn.Linear(input_dim, 512, bias=False).to(dtype)
self.activation = torch.nn.ReLU()
self.layer2 = torch.nn.Linear(512, output_dim, bias=False).to(dtype)
def forward(self, x):
x = self.layer1(x)
x = self.activation(x)
x = self.layer2(x)
return x
更新 `create_model_and_input_data` 函式以處理您的新模型型別。
def create_model_and_input_data(
model_type: str,
m: int,
k: int,
n: int,
high_precision_dtype: torch.dtype = torch.bfloat16,
device: str = "cuda",
activation: str = "relu",
):
# ... existing code ...
elif model_type == "my_custom_model":
model = MyCustomModel(k, n, high_precision_dtype).to(device)
input_data = torch.randn(m, k, device=device, dtype=high_precision_dtype)
# ... rest of existing code ...
模型設計注意事項¶
新增新模型時
輸入/輸出維度:確保您的模型處理 (m, k, n) 維度約定,其中
m:批次大小或序列長度k:輸入特徵維度n:輸出特徵維度
資料型別:支援 `high_precision_dtype` 引數(通常是 `torch.bfloat16`)。
裝置相容性:確保您的模型能在 CUDA、CPU 和其他目標裝置上執行。
量化相容性:設計您的模型以與 TorchAO 量化方法相容。
將 HF 模型新增到基準測試實用技巧¶
(即將推出!!!)
將 API 新增到基準測試 CI 儀表板¶
要將您的 API 整合到 CI 儀表板
1. 修改現有的 CI 配置¶
在 `benchmarks/dashboard/microbenchmark_quantization_config.yml` 檔案中將您的量化方法新增到現有的 CI 配置檔案中。
# benchmarks/dashboard/microbenchmark_quantization_config.yml
benchmark_mode: "inference"
quantization_config_recipe_names:
- "int8wo"
- "int8dq"
- "float8dq-tensor"
- "float8dq-row"
- "float8wo"
- "my_new_quantization" # Add your method here
output_dir: "benchmarks/microbenchmarks/results"
model_params:
- name: "small_bf16_linear"
matrix_shapes:
- name: "small_sweep"
min_power: 10
max_power: 15
high_precision_dtype: "torch.bfloat16"
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
2. 執行 CI 基準測試¶
使用 CI 執行器生成 PyTorch OSS 基準測試資料庫格式的結果。
python benchmarks/dashboard/ci_microbenchmark_runner.py \
--config benchmarks/dashboard/microbenchmark_quantization_config.yml \
--output benchmark_results.json
3. CI 輸出格式¶
CI 執行器以 PyTorch OSS 基準測試資料庫所需的特定 JSON 格式輸出結果。
[
{
"benchmark": {
"name": "micro-benchmark api",
"mode": "inference",
"dtype": "int8wo",
"extra_info": {
"device": "cuda",
"arch": "NVIDIA A100-SXM4-80GB"
}
},
"model": {
"name": "1024-1024-1024",
"type": "micro-benchmark custom layer",
"origins": ["torchao"]
},
"metric": {
"name": "speedup(wrt bf16)",
"benchmark_values": [1.25],
"target_value": 0.0
},
"runners": [],
"dependencies": {}
}
]
4. 與 CI 流水線整合¶
要與您的 CI 流水線整合,請將基準測試步驟新增到您的工作流中。
# Example GitHub Actions step
- name: Run Microbenchmarks
run: |
python benchmarks/dashboard/ci_microbenchmark_runner.py \
--config benchmarks/dashboard/microbenchmark_quantization_config.yml \
--output benchmark_results.json
- name: Upload Results
# Upload benchmark_results.json to your dashboard system
故障排除¶
執行測試¶
驗證您的設定並執行測試套件。
python -m unittest discover benchmarks/microbenchmarks/test
常見問題¶
CUDA 記憶體不足:減小批次大小或矩陣維度。
缺少量化方法:確保 TorchAO 已正確安裝。
裝置不可用:檢查裝置可用性和驅動程式。
最佳實踐¶
使用 `small_sweep` 進行基本測試,使用 `custom shapes` 進行全面或特定模型的分析。
僅在需要時啟用分析(會增加開銷)。
儘可能在多個裝置上進行測試。
使用一致的命名約定以提高可重現性。
有關基準測試不同用例的資訊,請參閱 基準測試使用者指南。
有關框架元件更詳細的資訊,請參閱 `benchmarks/microbenchmarks/` 目錄中的 README 檔案。