評價此頁
torch.compile">

使用使用者定義的 Triton 核心配合 torch.compile#

創建於: 2024 年 4 月 19 日 | 最後更新: 2025 年 5 月 2 日 | 最後驗證: 2024 年 11 月 5 日

作者: Oguz Ulgen

使用者定義的 Triton 核心可用於最佳化模型計算的特定部分。這些核心是用 Triton 語言編寫的,該語言旨在更容易地實現硬體的峰值效能。透過將使用者定義的 Triton 核心與 torch.compile 一起使用,您可以將這些最佳化的計算整合到您的 PyTorch 模型中,從而可能獲得顯著的效能提升。

本示例演示瞭如何將使用者定義的 Triton 核心與 torch.compile 一起使用。

先決條件#

在開始此秘籍之前,請確保您已具備以下條件

import torch
from torch.utils._triton import has_triton

基本用法#

在本示例中,我們將使用 Triton 文件中的一個簡單的向量加法核心與 torch.compile 結合使用。作為參考,請參閱 Triton 文件

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.jit
    def add_kernel(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([-1.2986,  0.5554,  2.0662, -0.0785], device='cuda:0')
Y:      tensor([ 0.1402,  1.0969, -0.6538,  0.1946], device='cuda:0')
is equal to
tensor([-1.1584,  1.6523,  1.4124,  0.1160], device='cuda:0')

高階用法#

Triton 的自動調優功能是一個強大的工具,可自動最佳化 Triton 核心的配置引數。它會探索一系列可能的配置,並選擇最適合您特定用例的配置。

當與 torch.compile 一起使用時,triton.autotune 可以幫助確保您的 PyTorch 模型以儘可能高的效率執行。以下是使用 torch.compiletriton.autotune 的示例。

注意

torch.compile 僅支援 triton.autotune 的配置和關鍵字引數。

if not has_triton():
    print("Skipping because triton is not supported on this device.")
else:
    import triton
    from triton import language as tl

    @triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
            triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
        ],
        key=[],
    )
    @triton.jit
    def add_kernel_autotuned(
        in_ptr0,
        in_ptr1,
        out_ptr,
        n_elements,
        BLOCK_SIZE: "tl.constexpr",
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(in_ptr0 + offsets, mask=mask)
        y = tl.load(in_ptr1 + offsets, mask=mask)
        output = x + y
        tl.store(out_ptr + offsets, output, mask=mask)

    @torch.compile(fullgraph=True)
    def add_fn(x, y):
        output = torch.zeros_like(x)
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        add_kernel_autotuned[grid](x, y, output, n_elements)
        return output

    x = torch.randn(4, device="cuda")
    y = torch.randn(4, device="cuda")
    out = add_fn(x, y)
    print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
Vector addition of
X:      tensor([0.9884, 1.0321, 0.0053, 0.6395], device='cuda:0')
Y:      tensor([-0.8745,  0.7483, -2.6247, -0.2872], device='cuda:0')
is equal to
tensor([ 0.1139,  1.7804, -2.6194,  0.3523], device='cuda:0')

可組合性#

使用者定義的 Triton 核心並不自動支援所有 PyTorch 子系統。這在以下用例中可以看到:

  • 新增 CPU 回退

  • 新增 FlopCounter 公式

  • 與 Tensor 子類組合

要與額外的 PyTorch 子系統組合,請使用 torch.library.triton_op

triton_op 是一種定義自定義運算子的結構化方法,該運算子由一個或多個 Triton 核心支援:與常規自定義運算子(torch.library.custom_op)一樣,您可以透過 torch.library 指定與 PyTorch 子系統的互動。然而,與 torch.library.custom_op(它建立了相對於 torch.compile 不透明的可呼叫物件)不同,torch.compile 會跟蹤 triton_op 以應用最佳化。

這是將 Triton 核心與 PyTorch 整合時使用哪個 API 的圖表。

Triton 核心(無顯式 torch.library 包裝器)

torch.library.triton_op

torch.library.custom_op

支援推理

支援訓練

在大多數情況下

支援 torch.compile

支援 torch.compile(fullgraph=True)

在大多數情況下

在大多數情況下

在所有情況下

torch.compile 是否跟蹤實現?

支援 AOTInductor

支援 FlopCounterMode、CPU 回退、Tensor 子類等 PyTorch 子系統

使用 triton_op 包裝 Triton 核心#

使用 torch.library.triton_op 來包裝可能呼叫一個或多個 Triton 核心的函式。使用 torch.library.wrap_triton 來包裝對 Triton 核心的呼叫。

from torch.library import triton_op, wrap_triton

@triton_op("mylib::mysin", mutates_args={})
def mysin(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

@triton.jit
def sin_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.sin(x)
    tl.store(out_ptr + offsets, output, mask=mask)

您可以透過以下兩種方式之一呼叫 triton_op

x = torch.randn(3, device="cuda")
y = mysin(x)
z = torch.ops.mylib.mysin.default(x)

assert torch.allclose(y, x.sin())
assert torch.allclose(z, x.sin())

生成的 triton_op 可與 torch.compileAOTInductor 一起使用。

y = torch.compile(mysin)(x)
assert torch.allclose(y, x.sin())

新增訓練支援#

使用 register_autogradtriton_op 新增自動求導公式。優先使用此方法,而不是 torch.autograd.Function(它與 torch.compile 存在各種組合陷阱)。

def backward(ctx, grad):
    x, = ctx.saved_tensors
    return grad * x.cos()

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

請注意,後向傳播必須是 PyTorch 可理解運算子的組合。如果您希望後向傳播呼叫 Triton 核心,那麼這些核心也必須用 triton_op 包裝。

@triton.jit
def cos_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    output = tl.cos(x)
    tl.store(out_ptr + offsets, output, mask=mask)

@triton_op("mylib::mycos", mutates_args={})
def mycos(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
    return out

def backward(ctx, grad):
    x, = ctx.saved_tensors
    return grad * mycos(x)

def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

mysin.register_autograd(backward, setup_context=setup_context)

新增 CPU 回退#

Triton 核心不在 CPU 上執行。使用 register_kerneltriton_op 新增 CPU(或任何其他裝置)回退。

@mysin.register_kernel("cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(3)
y = mysin(x)
assert torch.allclose(y, x.sin())

回退必須由 PyTorch 運算子組成。

新增 FlopCounter 公式#

要指定 Triton 核心在 PyTorch 的浮點運算計數器下報告多少次浮點運算,請使用 register_flop_formula

from torch.utils.flop_counter import FlopCounterMode, register_flop_formula

@register_flop_formula(torch.ops.mylib.mysin)
def _(x_shape):
    numel = 1
    for s in x_shape:
        numel *= s
    return numel

x = torch.randn(3, device="cuda")

FlopCounterMode 需要 tabulate。在執行以下程式碼之前,請確保您已安裝 tabulate,或者透過執行 pip install tabulate 進行安裝。

>>> with FlopCounterMode() as flop_counter:
>>>     y = mysin(x)

侷限性#

截至 PyTorch 2.3,torch.compile 對使用者定義的 Triton 核心的支援包括動態形狀、torch.autograd.Function、JIT Inductor 和 AOT Inductor。您可以將這些功能結合使用來構建複雜的高效能模型。

PyTorch 2.6 添加了 torch.library.triton_op,它增加了對 Tensor 子類和其他高階功能中的使用者定義 Triton 核心的支援。

但是,有一些限制需要注意:

  • Triton 功能: 雖然 triton.heuristics 可以獨立使用,或者在 triton.autotune 之前使用,但不能在 triton.autotune 之後使用。這意味著如果 triton.heuristicstriton.autotune 要一起使用,則必須先使用 triton.heuristics

結論#

在此示例中,我們探討了如何將使用者定義的 Triton 核心與 torch.compile 一起使用。我們深入研究了一個簡單的向量加法核心的基本用法以及涉及 Triton 自動調優功能的高階用法。我們還討論了使用者定義的 Triton 核心與其他 PyTorch 功能的可組合性,並強調了一些當前的侷限性。

另請參閱#

指令碼總執行時間: (0 分 3.575 秒)