評價此頁

MKLDNN 後端#

建立日期:2025 年 5 月 10 日 | 最後更新日期:2025 年 7 月 17 日

MKLDNN 是一個開源的跨平臺效能庫,包含深度學習應用程式的基本構建塊。

# The flag below controls whether enable MKLDNN backend in Pytorch.
torch.backends.mkldnn.enabled = True

使用者可以透過以下方式停用 MKLDNN 後端:

torch.backends.mkldnn.enabled = False

MKLDNN 後端的 Bfloat16 (BF16)#

從 PyTorch 2.9 開始,提供了一組 API 來控制 float32 運算子的內部計算精度。

# The flag below controls the internal computation precision for mkldnn matmul. Default ieee is float32.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"

# The flag below controls the internal computation precision for mkldnn conv. Default ieee is float32.
torch.backends.mkldnn.conv.fp32_precision = "ieee"

# The flag below controls the internal computation precision for mkldnn rnn. Default ieee is float32.
torch.backends.mkldnn.rnn.fp32_precision = "ieee"

請注意,除了 matmuls 和 convolutions 本身之外,內部使用 matmuls 或 convolutions 的函式和 nn 模組也會受到影響。這些包括 torch.nn.Lineartorch.nn._ConvNdtorch.cdist()torch.tensordot()torch.nn.functional.affine_grid()torch.nn.functional.grid_sample()torch.nn.AdaptiveLogSoftmaxWithLosstorch.nn.GRUtorch.nn.LSTM

為了瞭解精度和速度,請參見下面的示例程式碼和基準測試資料(在 SPR 上)。

torch.manual_seed(0)
a_full = torch.randn(10240, 10240, dtype=torch.double)
b_full = torch.randn(10240, 10240, dtype=torch.double)
ab_full = a_full @ b_full
mean = ab_full.abs().mean()  # 80.7451

a = a_full.float()
b = b_full.float()

# Do matmul at BF16 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'bf16'
ab_bf16 = a @ b  # expected speedup with BF16 dot-product acceleration
error = (ab_bf16 - ab_full).abs().max()  # 1.3704
relative_error = error / mean  # 0.0170
print(error, relative_error)

# Do matmul at TF32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'tf32'
ab_tf32 = a @ b  # expected speedup with TF32 dot-product acceleration
error = (ab_tf32 - ab_full).abs().max()  # 0.0004
relative_error = error / mean  # 0.00000552
print(error, relative_error)

# Do matmul FP32 mode.
torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
ab_fp32 = a @ b
error = (ab_fp32 - ab_full).abs().max()  # 0.0003
relative_error = error / mean  # 0.00000317
print(error, relative_error)

從上面的示例可以看出,使用 BF16,在 SPR 上的速度提高了約 7 倍,與雙精度相比,相對誤差大約大 2 個數量級。如果需要完整的 FP32 精度,使用者可以透過以下方式停用 BF16:

torch.backends.mkldnn.matmul.fp32_precision = 'ieee'
torch.backends.mkldnn.conv.fp32_precision = 'ieee'
torch.backends.mkldnn.rnn.fp32_precision = 'ieee'

要停用 C++ 中的 BF16 標誌,您可以執行以下操作:

at::globalContext().setFloat32Precision("ieee", "mkldnn", "matmul");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "conv");
at::globalContext().setFloat32Precision("ieee", "mkldnn", "rnn");

如果 fp32_precision 設定為 ieee,我們可以為特定的運算子或後端覆蓋通用設定。

torch.backends.fp32_precision = "bf16"
torch.backends.mkldnn.fp32_precision = "ieee"
torch.backends.mkldnn.matmul.fp32_precision = "ieee"

在這種情況下,torch.backends.mkldnn.fp32_precisiontorch.backends.mkldnn.matmul.fp32_precision 都將被覆蓋為 bf16。