自動混合精度包 - torch.amp#
建立日期:2025 年 6 月 12 日 | 最後更新日期:2025 年 6 月 12 日
torch.amp 提供了混合精度的便捷方法,其中一些操作使用 torch.float32 (float)資料型別,而其他操作使用較低精度的浮點資料型別(lower_precision_fp):torch.float16 (half)或 torch.bfloat16。一些操作,如線性層和卷積,在 lower_precision_fp 中速度更快。其他操作,如歸約,通常需要 float32 的動態範圍。混合精度試圖將每個操作與其適當的資料型別相匹配。
通常,使用 torch.float16 資料型別的“自動混合精度訓練”會一起使用 torch.autocast 和 torch.amp.GradScaler,如 自動混合精度示例 和 自動混合精度食譜 中所示。然而,torch.autocast 和 torch.GradScaler 是模組化的,如果需要,可以單獨使用。如 torch.autocast 的 CPU 示例部分所示,“CPU 上的自動混合精度訓練/推理”使用 torch.bfloat16 資料型別,只使用 torch.autocast。
警告
torch.cuda.amp.autocast(args...) 和 torch.cpu.amp.autocast(args...) 已棄用。請改用 torch.amp.autocast("cuda", args...) 或 torch.amp.autocast("cpu", args...)。torch.cuda.amp.GradScaler(args...) 和 torch.cpu.amp.GradScaler(args...) 已棄用。請改用 torch.amp.GradScaler("cuda", args...) 或 torch.amp.GradScaler("cpu", args...)。
torch.autocast 和 torch.cpu.amp.autocast 在版本 1.10 中是新加入的。
自動型別轉換#
- torch.amp.autocast_mode.is_autocast_available(device_type)[source]#
返回一個布林值,指示
device_type上是否可用自動型別轉換。- 引數
device_type (str) – 要使用的裝置型別。可能的值包括:“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。該型別與
torch.device的 type 屬性相同。因此,您可以使用 Tensor.device.type 獲取張量的裝置型別。- 返回型別
- class torch.autocast(device_type, dtype=None, enabled=True, cache_enabled=None)[source]#
autocast的例項用作上下文管理器或裝飾器,允許您的指令碼的某些區域以混合精度執行。在這些區域中,操作以自動型別轉換選擇的特定於操作的資料型別執行,以提高效能同時保持準確性。有關詳細資訊,請參閱 Autocast 操作參考。
進入啟用自動型別轉換的區域時,張量可以是任何型別。在使用自動型別轉換時,不應在模型或輸入上呼叫
half()或bfloat16()。autocast應僅包裝網路的正向傳播(包括損失計算)。不建議在自動型別轉換下進行反向傳播。反向操作以與自動型別轉換用於相應正向操作相同的型別執行。CUDA 裝置示例
# Creates model and optimizer in default precision model = Net().cuda() optimizer = optim.SGD(model.parameters(), ...) for input, target in data: optimizer.zero_grad() # Enables autocasting for the forward pass (model + loss) with torch.autocast(device_type="cuda"): output = model(input) loss = loss_fn(output, target) # Exits the context manager before backward() loss.backward() optimizer.step()
有關更復雜的場景(例如,梯度懲罰、多個模型/損失、自定義 autograd 函式)中的用法(以及梯度縮放),請參閱 自動混合精度示例。
autocast也可以用作裝飾器,例如,在模型的forward方法上class AutocastModel(nn.Module): ... @torch.autocast(device_type="cuda") def forward(self, input): ...
在啟用自動型別轉換的區域中生成的浮點張量可能是
float16。返回到停用自動型別轉換的區域後,將它們與不同資料型別的浮點張量一起使用可能會導致型別不匹配錯誤。如果是這種情況,請將從自動型別轉換區域生成的張量轉換回float32(或所需的其他資料型別)。如果來自自動型別轉換區域的張量已經是float32,則轉換操作無效,不會產生額外開銷。CUDA 示例# Creates some tensors in default dtype (here assumed to be float32) a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") c_float32 = torch.rand((8, 8), device="cuda") d_float32 = torch.rand((8, 8), device="cuda") with torch.autocast(device_type="cuda"): # torch.mm is on autocast's list of ops that should run in float16. # Inputs are float32, but the op runs in float16 and produces float16 output. # No manual casts are required. e_float16 = torch.mm(a_float32, b_float32) # Also handles mixed input types f_float16 = torch.mm(d_float32, e_float16) # After exiting autocast, calls f_float16.float() to use with d_float32 g_float32 = torch.mm(d_float32, f_float16.float())
CPU 訓練示例
# Creates model and optimizer in default precision model = Net() optimizer = optim.SGD(model.parameters(), ...) for epoch in epochs: for input, target in data: optimizer.zero_grad() # Runs the forward pass with autocasting. with torch.autocast(device_type="cpu", dtype=torch.bfloat16): output = model(input) loss = loss_fn(output, target) loss.backward() optimizer.step()
CPU 推理示例
# Creates model in default precision model = Net().eval() with torch.autocast(device_type="cpu", dtype=torch.bfloat16): for input in data: # Runs the forward pass with autocasting. output = model(input)
帶 Jit Trace 的 CPU 推理示例
class TestModel(nn.Module): def __init__(self, input_size, num_classes): super().__init__() self.fc1 = nn.Linear(input_size, num_classes) def forward(self, x): return self.fc1(x) input_size = 2 num_classes = 2 model = TestModel(input_size, num_classes).eval() # For now, we suggest to disable the Jit Autocast Pass, # As the issue: https://github.com/pytorch/pytorch/issues/75956 torch._C._jit_set_autocast_mode(False) with torch.cpu.amp.autocast(cache_enabled=False): model = torch.jit.trace(model, torch.randn(1, input_size)) model = torch.jit.freeze(model) # Models Run for _ in range(3): model(torch.randn(1, input_size))
在啟用自動型別轉換的區域中發生的型別不匹配錯誤是一個 bug;如果您遇到這種情況,請提交問題。
autocast(enabled=False)子區域可以巢狀在啟用自動型別轉換的區域中。區域性停用自動型別轉換可能很有用,例如,如果您想強制一個子區域以特定dtype執行。停用自動型別轉換可讓您顯式控制執行型別。在子區域中,應在內部使用前將來自周圍區域的輸入轉換為dtype。# Creates some tensors in default dtype (here assumed to be float32) a_float32 = torch.rand((8, 8), device="cuda") b_float32 = torch.rand((8, 8), device="cuda") c_float32 = torch.rand((8, 8), device="cuda") d_float32 = torch.rand((8, 8), device="cuda") with torch.autocast(device_type="cuda"): e_float16 = torch.mm(a_float32, b_float32) with torch.autocast(device_type="cuda", enabled=False): # Calls e_float16.float() to ensure float32 execution # (necessary because e_float16 was created in an autocasted region) f_float32 = torch.mm(c_float32, e_float16.float()) # No manual casts are required when re-entering the autocast-enabled region. # torch.mm again runs in float16 and produces float16 output, regardless of input types. g_float16 = torch.mm(d_float32, f_float32)
自動型別轉換狀態是執行緒本地的。如果您希望在新執行緒中啟用它,則必須在該執行緒中呼叫上下文管理器或裝飾器。這會影響
torch.nn.DataParallel和torch.nn.parallel.DistributedDataParallel在用於多個 GPU(每個程序)時(請參閱 使用多個 GPU)。- 引數
device_type (str, required) – 要使用的裝置型別。可能的值包括:“cuda”、“cpu”、“mtia”、“maia”、“xpu”和“hpu”。該型別與
torch.device的 type 屬性相同。因此,您可以使用 Tensor.device.type 獲取張量的裝置型別。enabled (bool, optional) – 是否應在區域中啟用自動型別轉換。預設值:
Truedtype (torch_dtype, optional) – 在自動型別轉換中執行的操作的資料型別。如果
dtype為None,它將使用get_autocast_dtype()提供的預設值(CUDA 為torch.float16,CPU 為torch.bfloat16)。預設值:Nonecache_enabled (bool, optional) – 是否應啟用自動型別轉換內部的權重快取。預設值:
True
- torch.amp.custom_fwd(fwd=None, *, device_type, cast_inputs=None)[source]#
建立用於自定義 autograd 函式的
forward方法的輔助裝飾器。Autograd 函式是
torch.autograd.Function的子類。有關更多詳細資訊,請參閱 示例頁面。- 引數
device_type (str) – 要使用的裝置型別。“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。該型別與
torch.device的 type 屬性相同。因此,您可以使用 Tensor.device.type 獲取張量的裝置型別。cast_inputs (
torch.dtype或 None, optional, default=None) – 如果不為None,當forward在啟用自動型別轉換的區域中執行時,它會將傳入的浮點張量轉換為目標資料型別(非浮點張量不受影響),然後停用自動型別轉換執行forward。如果為None,則forward的內部操作將與當前的自動型別轉換狀態一起執行。
注意
如果裝飾的
forward在自動型別轉換停用區域之外呼叫,custom_fwd將無效,cast_inputs也將不起作用。
- torch.amp.custom_bwd(bwd=None, *, device_type)[source]#
建立用於自定義 autograd 函式的 backward 方法的輔助裝飾器。
Autograd 函式是
torch.autograd.Function的子類。確保backward以與forward相同的自動型別轉換狀態執行。有關更多詳細資訊,請參閱 示例頁面。- 引數
device_type (str) – 要使用的裝置型別。“cuda”、“cpu”、“mtia”、“maia”、“xpu”等。該型別與
torch.device的 type 屬性相同。因此,您可以使用 Tensor.device.type 獲取張量的裝置型別。
- class torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True)[source]#
參見
torch.autocast。torch.cuda.amp.autocast(args...)已棄用。請改用torch.amp.autocast("cuda", args...)。
- torch.cuda.amp.custom_fwd(fwd=None, *, cast_inputs=None)[source]#
torch.cuda.amp.custom_fwd(args...)已棄用。請改用torch.amp.custom_fwd(args..., device_type='cuda')。
- torch.cuda.amp.custom_bwd(bwd)[source]#
torch.cuda.amp.custom_bwd(args...)已棄用。請改用torch.amp.custom_bwd(args..., device_type='cuda')。
- class torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True)[source]#
參見
torch.autocast。torch.cpu.amp.autocast(args...)已棄用。請改用torch.amp.autocast("cpu", args...)。
梯度縮放#
如果特定操作的正向傳播具有 float16 輸入,則該操作的反向傳播將產生 float16 梯度。幅度小的梯度值可能無法在 float16 中表示。這些值將變為零(“下溢”),因此相應引數的更新將丟失。
為防止下溢,“梯度縮放”將網路的損失乘以一個比例因子,並對縮放後的損失呼叫反向傳播。換句話說,梯度值具有更大的幅度,因此它們不會變為零。
在最佳化器更新引數之前,應取消每個引數的梯度(.grad 屬性)的縮放,以免比例因子干擾學習率。
注意
AMP/fp16 可能並非適用於所有模型!例如,大多數 bf16 預訓練的模型無法在最大值為 65504 的 fp16 數值範圍內執行,並且會導致梯度溢位而非下溢。在這種情況下,比例因子可能會減小到 1 以下,以嘗試將梯度恢復到 fp16 動態範圍內可表示的數字。雖然您可能期望比例因子始終大於 1,但我們的 GradScaler 不做此保證以維持效能。如果您在使用 AMP/fp16 時遇到損失或梯度中的 NaN,請驗證您的模型是否相容。
Autocast 操作參考#
操作資格#
以 float64 或非浮點資料型別執行的操作不符合資格,並且無論是否啟用自動型別轉換,都將以這些型別執行。
只有非原地操作和 Tensor 方法才符合資格。在啟用自動型別轉換的區域中允許原地變體和顯式提供 out=... Tensor 的呼叫,但它們不會透過自動型別轉換。例如,在啟用自動型別轉換的區域中,a.addmm(b, c) 可以自動型別轉換,但 a.addmm_(b, c) 和 a.addmm(b, c, out=d) 不能。為了獲得最佳效能和穩定性,請在啟用自動型別轉換的區域中優先使用非原地操作。
使用顯式 dtype=... 引數呼叫的操作不符合資格,並且將產生尊重 dtype 引數的輸出。
CUDA 操作特定行為#
以下列表描述了在啟用自動型別轉換的區域中符合資格的操作的行為。無論這些操作是作為 torch.nn.Module、函式還是 torch.Tensor 方法呼叫,它們都會透過自動型別轉換。如果函式在多個名稱空間中公開,則無論名稱空間如何,它們都將透過自動型別轉換。
下面未列出的操作不會透過自動型別轉換。它們以其輸入定義的資料型別執行。但是,如果未列出的操作位於自動型別轉換操作的下游,自動型別轉換仍可能更改它們執行的資料型別。
如果一個操作未列出,我們假設它在 float16 中是數值穩定的。如果您認為未列出的操作在 float16 中不具有數值穩定性,請提交一個問題。
可自動轉換為 float16 的 CUDA 操作#
__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell
可自動轉換為 float32 的 CUDA 操作#
__pow__, __rdiv__, __rpow__, __rtruediv__, acos, asin, binary_cross_entropy_with_logits, cosh, cosine_embedding_loss, cdist, cosine_similarity, cross_entropy, cumprod, cumsum, dist, erfinv, exp, expm1, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, log10, log1p, log2, margin_ranking_loss, mse_loss, multilabel_margin_loss, multi_margin_loss, nll_loss, norm, normalize, pdist, poisson_nll_loss, pow, prod, reciprocal, rsqrt, sinh, smooth_l1_loss, soft_margin_loss, softmax, softmin, softplus, sum, renorm, tan, triplet_margin_loss
提升至最寬輸入型別的 CUDA 操作#
這些操作不需要特定資料型別即可保持穩定性,但它們接受多個輸入,並要求輸入的 Dtype 匹配。如果所有輸入都是 float16,則操作以 float16 執行。如果任何輸入是 float32,則自動型別轉換會將所有輸入轉換為 float32 並以 float32 執行操作。
addcdiv, addcmul, atan2, bilinear, cross, dot, grid_sample, index_put, scatter_add, tensordot
此處未列出的某些操作(例如,add 等二元操作)會在自動型別轉換干預之前原生提升輸入。如果輸入是 float16 和 float32 的混合,這些操作將以 float32 執行併產生 float32 輸出,無論是否啟用自動型別轉換。
優先使用 binary_cross_entropy_with_logits 而非 binary_cross_entropy#
torch.nn.functional.binary_cross_entropy()(以及包裝它的 torch.nn.BCELoss)的反向傳播會產生無法在 float16 中表示的梯度。在啟用自動型別轉換的區域中,正向輸入可能是 float16,這意味著反向梯度必須在 float16 中表示(將 float16 正向輸入自動型別轉換為 float32 無濟於事,因為這種轉換必須在反向傳播中進行逆轉)。因此,binary_cross_entropy 和 BCELoss 在啟用自動型別轉換的區域中會引發錯誤。
許多模型在二元交叉熵層之前使用 sigmoid 層。在這種情況下,結合這兩個層使用 torch.nn.functional.binary_cross_entropy_with_logits() 或 torch.nn.BCEWithLogitsLoss。binary_cross_entropy_with_logits 和 BCEWithLogits 可以安全地進行自動型別轉換。
XPU 操作特定行為(實驗性)#
以下列表描述了在啟用自動型別轉換的區域中符合資格的操作的行為。無論這些操作是作為 torch.nn.Module、函式還是 torch.Tensor 方法呼叫,它們都會透過自動型別轉換。如果函式在多個名稱空間中公開,則無論名稱空間如何,它們都將透過自動型別轉換。
下面未列出的操作不會透過自動型別轉換。它們以其輸入定義的資料型別執行。但是,如果未列出的操作位於自動型別轉換操作的下游,自動型別轉換仍可能更改它們執行的資料型別。
如果一個操作未列出,我們假設它在 float16 中是數值穩定的。如果您認為未列出的操作在 float16 中不具有數值穩定性,請提交一個問題。
可自動轉換為 float16 的 XPU 操作#
addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, RNNCell
可自動轉換為 float32 的 XPU 操作#
__pow__, __rdiv__, __rpow__, __rtruediv__, binary_cross_entropy_with_logits, cosine_embedding_loss, cosine_similarity, cumsum, dist, exp, group_norm, hinge_embedding_loss, kl_div, l1_loss, layer_norm, log, log_softmax, margin_ranking_loss, nll_loss, normalize, poisson_nll_loss, pow, reciprocal, rsqrt, soft_margin_loss, softmax, softmin, sum, triplet_margin_loss
提升至最寬輸入型別的 XPU 操作#
這些操作不需要特定資料型別即可保持穩定性,但它們接受多個輸入,並要求輸入的 Dtype 匹配。如果所有輸入都是 float16,則操作以 float16 執行。如果任何輸入是 float32,則自動型別轉換會將所有輸入轉換為 float32 並以 float32 執行操作。
bilinear, cross, grid_sample, index_put, scatter_add, tensordot
此處未列出的某些操作(例如,add 等二元操作)會在自動型別轉換干預之前原生提升輸入。如果輸入是 float16 和 float32 的混合,這些操作將以 float32 執行併產生 float32 輸出,無論是否啟用自動型別轉換。
CPU 操作特定行為#
以下列表描述了在啟用自動型別轉換的區域中符合資格的操作的行為。無論這些操作是作為 torch.nn.Module、函式還是 torch.Tensor 方法呼叫,它們都會透過自動型別轉換。如果函式在多個名稱空間中公開,則無論名稱空間如何,它們都將透過自動型別轉換。
下面未列出的操作不會透過自動型別轉換。它們以其輸入定義的資料型別執行。但是,如果未列出的操作位於自動型別轉換操作的下游,自動型別轉換仍可能更改它們執行的資料型別。
如果操作未列出,我們假設它在 bfloat16 中是數值穩定的。如果您認為未列出的操作在 bfloat16 中不具有數值穩定性,請提交一個問題。float16 與 bfloat16 的列表相同。
可自動轉換為 bfloat16 的 CPU 操作#
conv1d, conv2d, conv3d, bmm, mm, linalg_vecdot, baddbmm, addmm, addbmm, linear, matmul, _convolution, conv_tbc, mkldnn_rnn_layer, conv_transpose1d, conv_transpose2d, conv_transpose3d, prelu, scaled_dot_product_attention, _native_multi_head_attention
可自動轉換為 float32 的 CPU 操作#
avg_pool3d, binary_cross_entropy, grid_sampler, grid_sampler_2d, _grid_sampler_2d_cpu_fallback, grid_sampler_3d, polar, prod, quantile, nanquantile, stft, cdist, trace, view_as_complex, cholesky, cholesky_inverse, cholesky_solve, inverse, lu_solve, orgqr, inverse, ormqr, pinverse, max_pool3d, max_unpool2d, max_unpool3d, adaptive_avg_pool3d, reflection_pad1d, reflection_pad2d, replication_pad1d, replication_pad2d, replication_pad3d, mse_loss, cosine_embedding_loss, nll_loss, nll_loss2d, hinge_embedding_loss, poisson_nll_loss, cross_entropy_loss, l1_loss, huber_loss, margin_ranking_loss, soft_margin_loss, triplet_margin_loss, multi_margin_loss, ctc_loss, kl_div, multilabel_margin_loss, binary_cross_entropy_with_logits, fft_fft, fft_ifft, fft_fft2, fft_ifft2, fft_fftn, fft_ifftn, fft_rfft, fft_irfft, fft_rfft2, fft_irfft2, fft_rfftn, fft_irfftn, fft_hfft, fft_ihfft, linalg_cond, linalg_matrix_rank, linalg_solve, linalg_cholesky, linalg_svdvals, linalg_eigvals, linalg_eigvalsh, linalg_inv, linalg_householder_product, linalg_tensorinv, linalg_tensorsolve, fake_quantize_per_tensor_affine, geqrf, _lu_with_info, qr, svd, triangular_solve, fractional_max_pool2d, fractional_max_pool3d, adaptive_max_pool3d, multilabel_margin_loss_forward, linalg_qr, linalg_cholesky_ex, linalg_svd, linalg_eig, linalg_eigh, linalg_lstsq, linalg_inv_ex
提升至最寬輸入型別的 CPU 操作#
這些操作不需要特定資料型別即可保持穩定性,但它們接受多個輸入,並要求輸入的 Dtype 匹配。如果所有輸入都是 bfloat16,則操作以 bfloat16 執行。如果任何輸入是 float32,則自動型別轉換會將所有輸入轉換為 float32 並以 float32 執行操作。
cat, stack, index_copy
此處未列出的某些操作(例如,add 等二元操作)會在自動型別轉換干預之前原生提升輸入。如果輸入是 bfloat16 和 float32 的混合,這些操作將以 float32 執行併產生 float32 輸出,無論是否啟用自動型別轉換。