注意
跳轉到末尾 下載完整的示例程式碼。
使用使用者定義的 Triton 核心配合 torch.compile#
創建於: 2024 年 4 月 19 日 | 最後更新: 2025 年 5 月 2 日 | 最後驗證: 2024 年 11 月 5 日
作者: Oguz Ulgen
使用者定義的 Triton 核心可用於最佳化模型計算的特定部分。這些核心是用 Triton 語言編寫的,該語言旨在更容易地實現硬體的峰值效能。透過將使用者定義的 Triton 核心與 torch.compile 一起使用,您可以將這些最佳化的計算整合到您的 PyTorch 模型中,從而可能獲得顯著的效能提升。
本示例演示瞭如何將使用者定義的 Triton 核心與 torch.compile 一起使用。
先決條件#
在開始此秘籍之前,請確保您已具備以下條件
對
torch.compile和 Triton 的基本理解。參見PyTorch 2.3 或更高版本
支援 Triton 的 GPU
import torch
from torch.utils._triton import has_triton
基本用法#
在本示例中,我們將使用 Triton 文件中的一個簡單的向量加法核心與 torch.compile 結合使用。作為參考,請參閱 Triton 文件。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.jit
def add_kernel(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([-1.2986, 0.5554, 2.0662, -0.0785], device='cuda:0')
Y: tensor([ 0.1402, 1.0969, -0.6538, 0.1946], device='cuda:0')
is equal to
tensor([-1.1584, 1.6523, 1.4124, 0.1160], device='cuda:0')
高階用法#
Triton 的自動調優功能是一個強大的工具,可自動最佳化 Triton 核心的配置引數。它會探索一系列可能的配置,並選擇最適合您特定用例的配置。
當與 torch.compile 一起使用時,triton.autotune 可以幫助確保您的 PyTorch 模型以儘可能高的效率執行。以下是使用 torch.compile 和 triton.autotune 的示例。
注意
torch.compile 僅支援 triton.autotune 的配置和關鍵字引數。
if not has_triton():
print("Skipping because triton is not supported on this device.")
else:
import triton
from triton import language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
],
key=[],
)
@triton.jit
def add_kernel_autotuned(
in_ptr0,
in_ptr1,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
tl.store(out_ptr + offsets, output, mask=mask)
@torch.compile(fullgraph=True)
def add_fn(x, y):
output = torch.zeros_like(x)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
add_kernel_autotuned[grid](x, y, output, n_elements)
return output
x = torch.randn(4, device="cuda")
y = torch.randn(4, device="cuda")
out = add_fn(x, y)
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X: tensor([0.9884, 1.0321, 0.0053, 0.6395], device='cuda:0')
Y: tensor([-0.8745, 0.7483, -2.6247, -0.2872], device='cuda:0')
is equal to
tensor([ 0.1139, 1.7804, -2.6194, 0.3523], device='cuda:0')
可組合性#
使用者定義的 Triton 核心並不自動支援所有 PyTorch 子系統。這在以下用例中可以看到:
新增 CPU 回退
新增
FlopCounter公式與 Tensor 子類組合
要與額外的 PyTorch 子系統組合,請使用 torch.library.triton_op。
triton_op 是一種定義自定義運算子的結構化方法,該運算子由一個或多個 Triton 核心支援:與常規自定義運算子(torch.library.custom_op)一樣,您可以透過 torch.library 指定與 PyTorch 子系統的互動。然而,與 torch.library.custom_op(它建立了相對於 torch.compile 不透明的可呼叫物件)不同,torch.compile 會跟蹤 triton_op 以應用最佳化。
這是將 Triton 核心與 PyTorch 整合時使用哪個 API 的圖表。
Triton 核心(無顯式 |
|
|
|
|---|---|---|---|
支援推理 |
是 |
是 |
是 |
支援訓練 |
在大多數情況下 |
是 |
是 |
支援 |
是 |
是 |
是 |
支援 |
在大多數情況下 |
在大多數情況下 |
在所有情況下 |
torch.compile 是否跟蹤實現? |
是 |
是 |
否 |
支援 AOTInductor |
是 |
是 |
否 |
支援 FlopCounterMode、CPU 回退、Tensor 子類等 PyTorch 子系統 |
否 |
是 |
是 |
使用 triton_op 包裝 Triton 核心#
使用 torch.library.triton_op 來包裝可能呼叫一個或多個 Triton 核心的函式。使用 torch.library.wrap_triton 來包裝對 Triton 核心的呼叫。
from torch.library import triton_op, wrap_triton
@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
@triton.jit
def sin_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.sin(x)
tl.store(out_ptr + offsets, output, mask=mask)
您可以透過以下兩種方式之一呼叫 triton_op。
x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)
assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())
生成的 triton_op 可與 torch.compile 和 AOTInductor 一起使用。
y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())
新增訓練支援#
使用 register_autograd 為 triton_op 新增自動求導公式。優先使用此方法,而不是 torch.autograd.Function(它與 torch.compile 存在各種組合陷阱)。
請注意,後向傳播必須是 PyTorch 可理解運算子的組合。如果您希望後向傳播呼叫 Triton 核心,那麼這些核心也必須用 triton_op 包裝。
@triton.jit
def cos_kernel(
in_ptr0,
out_ptr,
n_elements,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
output = tl.cos(x)
tl.store(out_ptr + offsets, output, mask=mask)
@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
n_elements = x.numel()
wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
return out
def backward(ctx, grad):
x, = ctx.saved_tensors
return grad * mycos(x)
def setup_context(ctx, inputs, output):
x, = inputs
ctx.save_for_backward(x)
mysin.register_autograd(backward, setup_context=setup_context)
新增 CPU 回退#
Triton 核心不在 CPU 上執行。使用 register_kernel 為 triton_op 新增 CPU(或任何其他裝置)回退。
回退必須由 PyTorch 運算子組成。
新增 FlopCounter 公式#
要指定 Triton 核心在 PyTorch 的浮點運算計數器下報告多少次浮點運算,請使用 register_flop_formula。
from torch.utils.flop_counter import FlopCounterMode, register_flop_formula
@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
numel = 1
for s in x_shape:
numel *= s
return numel
x = torch.randn(3, device="cuda")
FlopCounterMode 需要 tabulate。在執行以下程式碼之前,請確保您已安裝 tabulate,或者透過執行 pip install tabulate 進行安裝。
侷限性#
截至 PyTorch 2.3,torch.compile 對使用者定義的 Triton 核心的支援包括動態形狀、torch.autograd.Function、JIT Inductor 和 AOT Inductor。您可以將這些功能結合使用來構建複雜的高效能模型。
PyTorch 2.6 添加了 torch.library.triton_op,它增加了對 Tensor 子類和其他高階功能中的使用者定義 Triton 核心的支援。
但是,有一些限制需要注意:
Triton 功能: 雖然
triton.heuristics可以獨立使用,或者在triton.autotune之前使用,但不能在triton.autotune之後使用。這意味著如果triton.heuristics和triton.autotune要一起使用,則必須先使用triton.heuristics。
結論#
在此示例中,我們探討了如何將使用者定義的 Triton 核心與 torch.compile 一起使用。我們深入研究了一個簡單的向量加法核心的基本用法以及涉及 Triton 自動調優功能的高階用法。我們還討論了使用者定義的 Triton 核心與其他 PyTorch 功能的可組合性,並強調了一些當前的侷限性。
另請參閱#
指令碼總執行時間: (0 分 3.575 秒)