convert_to_float8_training¶
- torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) Module[原始碼]¶
將 module 中的 torch.nn.Linear 替換為 Float8Linear。
- 引數:
module – 要修改的模組。
module_filter_fn – 如果指定,則只有透過過濾函式的 torch.nn.Linear 子類才會被替換。傳遞給過濾函式的輸入是模組例項和 FQN。
config (Float8LinearConfig) – 轉換為 float8 的配置
- 返回:
已修改的模組,其中線性層已被替換。
- 返回型別:
nn.Module