torchao.float8¶
主要的 float8 訓練 API¶
將 module 中的 torch.nn.Linear 替換為 Float8Linear。 |
其他 float8 訓練型別¶
用於將 torch.nn.Linear 模組轉換為 float8 進行訓練的配置。 |
|
用於將單個張量可能轉換為 float8 的配置 |
|
定義轉換為 float8 的縮放策略的粒度 |
|
為所有 float8 引數動態計算尺度。此函式應在最佳化器步驟之後執行。它執行一次 all-reduce 以計算所有 float8 權重的尺度。示例用法:model(input).sum().backward() optim.step() precompute_float8_dynamic_scale_for_fsdp(model)。 |