• 文件 >
  • (第 1 部分)使用 float8 進行預訓練
快捷方式

(第一部分) 使用 float8 進行預訓練

TorchAO 透過利用我們整合到合作伙伴框架中的量化和稀疏技術,提供端到端的預訓練、微調和推理模型最佳化流程。這是展示此端到端流程的 3 個教程中的第 1 部分,重點關注預訓練步驟。

_images/e2e_flow_part1.png

使用 torchao 和 float8 進行預訓練可以在 512 個 GPU 叢集上提供高達 1.5 倍的加速,在 2K H200 叢集上使用最新的 torchao.float8 行式(rowwise)配方可提供高達 1.34-1.43 倍的加速

在本教程中,我們將展示使用 torchao.float8 配方進行預訓練的兩種方法:

  1. 使用 torchtitan 進行預訓練,這是 PyTorch 官方預訓練框架,具有原生 torchao 整合。

  2. 直接使用 torchao 進行預訓練,將 torchao 的 float8 訓練配方整合到您自己的預訓練程式碼中。

使用 torchtitan 進行預訓練

在本教程中,我們將使用 torchtitan 和 torchao 的 float8 訓練配方(行式縮放和張量式縮放)來預訓練 Llama3-8B。

Torchtitan 是 PyTorch 的官方預訓練框架,它與 torchao 原生整合,並支援多種流行的旗艦模型,具有常見的並行形式、float8 訓練、分散式檢查點等。有關更多詳細資訊,請參閱 torchtitan 的文件

您可以使用此工作流程快速開始“開箱即用”的體驗。使用者通常會 fork torchtitan,並在準備好後在其基礎上進行構建。

先決條件

  1. (推薦) 使用 conda 或 venv 建立一個新的虛擬環境。

  2. 安裝 torchao.

  3. 安裝 torchtitan,包括“下載分詞器”步驟。

現在您可以開始使用以下任一配方進行預訓練作業了!

行式縮放

在 torchtitan 的根目錄下執行以下命令,以啟動一個在 8 個 GPU 上使用 float8 行式訓練的 Llama3-8B 訓練作業:

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8" --float8.recipe_name="rowwise"

當使用 1 個以上的 GPU 時,Torchtitan 將自動使用 FSDP2 來並行化訓練。要使用其他並行形式、修改超引數或更改其他訓練配置,您可以直接編輯 llama3_8b.toml 檔案或使用命令列標誌(執行命令並帶 --help 引數可檢視更多選項)。

您應該會看到類似以下的終端輸出:

[rank0]:[titan] 2025-06-04 08:51:48,074 - root - INFO - step:  1  loss: 12.2254  memory: 27.34GiB(28.78%)  tps: 375  tflops: 21.73  mfu: 2.20%
[rank0]:[titan] 2025-06-04 08:51:58,557 - root - INFO - step: 10  loss: 10.7069  memory: 30.99GiB(32.62%)  tps: 7,034  tflops: 407.35  mfu: 41.19%
[rank0]:[titan] 2025-06-04 08:52:10,224 - root - INFO - step: 20  loss:  8.9196  memory: 30.99GiB(32.62%)  tps: 7,022  tflops: 406.65  mfu: 41.12%
[rank0]:[titan] 2025-06-04 08:52:21,904 - root - INFO - step: 30  loss:  8.1423  memory: 30.99GiB(32.62%)  tps: 7,014  tflops: 406.23  mfu: 41.08%

如您所見,忽略預熱步驟,我們達到了約 7k TPS,峰值記憶體使用量為 30.99GB。為了與 bfloat16 訓練進行效能比較,您可以刪除 --model.converters="float8" --float8.recipe_name="rowwise" 標誌並執行相同的命令,以檢視 bfloat16 訓練的基線效能。

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile

您應該會看到以下輸出:

[rank0]:[titan] 2025-06-04 11:02:37,404 - root - INFO - step:  1  loss: 12.2611  memory: 27.22GiB(28.65%)  tps: 595  tflops: 34.47  mfu: 3.49%
[rank0]:[titan] 2025-06-04 11:02:49,027 - root - INFO - step: 10  loss: 10.4260  memory: 30.89GiB(32.51%)  tps: 6,344  tflops: 367.39  mfu: 37.15%
[rank0]:[titan] 2025-06-04 11:03:01,988 - root - INFO - step: 20  loss:  8.9482  memory: 30.89GiB(32.51%)  tps: 6,321  tflops: 366.06  mfu: 37.01%
[rank0]:[titan] 2025-06-04 11:03:14,991 - root - INFO - step: 30  loss:  8.1183  memory: 30.89GiB(32.51%)  tps: 6,300  tflops: 364.89  mfu: 36.89%
[rank0]:[titan] 2025-06-04 11:03:28,013 - root - INFO - step: 40  loss:  7.4659  memory: 30.89GiB(32.51%)  tps: 6,291  tflops: 364.36  mfu: 36.84%
[rank0]:[titan] 2025-06-04 11:03:39,769 - root - INFO - [GC] Peforming periodical GC collection. 0.02 seconds.

如您所見,bfloat16 基線達到了約 6.3k TPS,峰值記憶體使用量為 30.89GB。

這意味著我們的 float8 行式縮放配方比 bfloat16 基線實現了1.11 倍的更高吞吐量,同時峰值記憶體使用量幾乎相同!

請注意,使用張量式縮放配方可以實現更高的吞吐量提升,該配方在效能與準確度之間存在不同的權衡。

張量式縮放

使用張量式縮放的 Float8 訓練是預設配方,因此我們可以省略 --float8.recipe_name 標誌。

NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.compile --model.converters="float8"

您應該會看到類似以下的輸出:

[rank0]:[titan] 2025-06-04 10:52:19,648 - root - INFO - step:  1  loss: 12.2648  memory: 27.28GiB(28.71%)  tps: 557  tflops: 32.29  mfu: 3.26%
[rank0]:[titan] 2025-06-04 10:52:29,475 - root - INFO - step: 10  loss: 10.9106  memory: 30.91GiB(32.53%)  tps: 7,503  tflops: 434.53  mfu: 43.94%
[rank0]:[titan] 2025-06-04 10:52:40,166 - root - INFO - step: 20  loss:  9.0774  memory: 30.91GiB(32.53%)  tps: 7,663  tflops: 443.78  mfu: 44.87%
[rank0]:[titan] 2025-06-04 10:52:50,885 - root - INFO - step: 30  loss:  8.3233  memory: 30.91GiB(32.53%)  tps: 7,643  tflops: 442.66  mfu: 44.76%
[rank0]:[titan] 2025-06-04 10:53:01,613 - root - INFO - step: 40  loss:  7.6150  memory: 30.91GiB(32.53%)  tps: 7,637  tflops: 442.27  mfu: 44.72%

如您所見,我們達到了約 7.6k TPS,峰值記憶體使用量為 30.91GB,這比 bfloat16 基線高出 1.21 倍的吞吐量

選擇配方

簡而言之:行式縮放更適合優先考慮更準確的數值和訓練穩定性的作業,而張量式更適合優先考慮訓練吞吐量的作業。

張量式縮放的更高吞吐量是以略高的量化誤差為代價的(即,與行式縮放相比,數值完整性有所降低)。這是因為行式縮放使用更精細的縮放因子(每行而不是每張量),這限制了可能導致縮放過程中下溢的異常值的影響。

您可以在下面看到在 8xH100 GPU 上訓練 Llama3-8B 時,bfloat16、float8 張量式和 float8 行式訓練的損失曲線對比:

Loss curves for training Llama3-8B on 8xH100s with torchtitan using bfloat16, float8 tensorwise, and float8 rowwise training.

重要說明

  • 目前,torchtitan 中的 float8 訓練僅支援 2 個及以上的 GPU,不支援單個 GPU 訓練。

  • 您必須使用 --training.compile 來實現高效能。torchao float8 訓練配方是基於 torch.compile 原生構建的,因此可以直接使用!

直接使用 torchao 進行預訓練

在本教程中,我們將直接使用 torchao API 預訓練一個玩具模型。

您可以使用此工作流程將 torchao 直接整合到您自己的自定義預訓練程式碼中。

先決條件

  1. (推薦) 使用 conda 或 venv 建立一個新的虛擬環境。

  2. 安裝 torchao.

現在您可以直接將 torchao 整合到您的訓練程式碼中了!

模型轉換 API

用於將模型轉換為使用 float8 訓練的 torchao API 是:convert_to_float8_training。此 API 將遞迴地將模型中的 nn.Linear 模組轉換為使用 Float8Linear

您可以使用 module_filter_fn 引數來確定哪些 nn.Linear 層應被替換為使用 Float8Linear

您應該參考此效能基準表,以瞭解相對於 bfloat16,對於給定的 GEMM 大小可以預期什麼樣的效能提升。

下面是一個展示如何使用它的程式碼片段:

import torch
from torch import nn
import torch.nn.functional as F

from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_linear import Float8Linear
from torchao.float8 import convert_to_float8_training

# create model and sample input
m = nn.Sequential(
    nn.Linear(2048, 4096),
    nn.Linear(4096, 128),
    nn.Linear(128, 1),
).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

# optional: filter modules from being eligible for float8 conversion
def module_filter_fn(mod: torch.nn.Module, fqn: str):
    # don't convert the last module
    if fqn == "1":
        return False
    # don't convert linear modules with weight dimensions not divisible by 16
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

# convert specified `torch.nn.Linear` modules to `Float8Linear`
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# enable torch.compile for competitive performance
m = torch.compile(m)

# toy training loop
for _ in range(10):
    optimizer.zero_grad()
    output = m(x)
    # use fake labels for demonstration purposes
    fake_labels = torch.ones_like(output)
    loss = F.mse_loss(output, fake_labels)
    loss.backward()
    optimizer.step()

# save the model
torch.save({
    'model': m,
    'model_state_dict': m.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')

在預訓練模型後,您可以選擇將其微調到更特定於域的資料集,併為其在推理時的量化進行適配。在本教程的下一部分中,我們將探討在微調步驟中的一些模型最佳化選項。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源