MKLDNN 後端#
建立日期:2025 年 5 月 10 日 | 最後更新日期:2025 年 7 月 17 日
MKLDNN 是一個開源的跨平臺效能庫,包含深度學習應用程式的基本構建塊。
# The flag below controls whether enable MKLDNN backend in Pytorch.
torch.backends.mkldnn.enabled = True
使用者可以透過以下方式停用 MKLDNN 後端:
torch.backends.mkldnn.enabled = False
MKLDNN 後端的 Bfloat16 (BF16)#
從 PyTorch 2.9 開始,提供了一組 API 來控制 float32 運算子的內部計算精度。
# The flag below controls the internal computation precision for mkldnn matmul. Default ieee is float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn conv. Default ieee is float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"
# The flag below controls the internal computation precision for mkldnn rnn. Default ieee is float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"
請注意,除了 matmuls 和 convolutions 本身之外,內部使用 matmuls 或 convolutions 的函式和 nn 模組也會受到影響。這些包括 torch.nn.Linear、torch.nn._ConvNd、torch.cdist()、torch.tensordot()、torch.nn.functional.affine_grid() 和 torch.nn.functional.grid_sample()、torch.nn.AdaptiveLogSoftmaxWithLoss、torch.nn.GRU 和 torch.nn.LSTM。
為了瞭解精度和速度,請參見下面的示例程式碼和基準測試資料(在 SPR 上)。
torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean() # 80.7451
a = a_full.float()
b = b_full.float()
# Do matmul at BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max() # 1.3704
relative_error = error / mean # 0.0170
print(error, relative_error)
# Do matmul at TF32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'tf32'
ab_tf32 = a @ b # expected speedup with TF32 dot-product acceleration
error = (ab_tf32 - ab_full).abs().max() # 0.0004
relative_error = error / mean # 0.00000552
print(error, relative_error)
# Do matmul FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max() # 0.0003
relative_error = error / mean # 0.00000317
print(error, relative_error)
從上面的示例可以看出,使用 BF16,在 SPR 上的速度提高了約 7 倍,與雙精度相比,相對誤差大約大 2 個數量級。如果需要完整的 FP32 精度,使用者可以透過以下方式停用 BF16:
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'
要停用 C++ 中的 BF16 標誌,您可以執行以下操作:
at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn");
如果 fp32_precision 設定為 ieee,我們可以為特定的運算子或後端覆蓋通用設定。
torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
在這種情況下,torch.backends.mkldnn.fp32_precision 和 torch.backends.mkldnn.matmul.fp32_precision 都將被覆蓋為 bf16。