自定義後端#
創建於: 2025 年 6 月 10 日 | 最後更新於: 2025 年 6 月 10 日
概述#
torch.compile 提供了一種直接的方法,使使用者能夠定義自定義後端。
後端函式具有 (gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable 的約定。
在跟蹤完 FX 圖後,TorchDynamo(torch.compile 的圖跟蹤元件)可以呼叫後端函式,並期望它返回一個與跟蹤的 FX 圖等效的編譯後的函式。返回的可呼叫物件應該與傳遞給後端的 torch.fx.GraphModule 的 forward 函式具有相同的約定:(*args: torch.Tensor) -> List[torch.Tensor]。
為了讓 TorchDynamo 呼叫您的後端,請在 torch.compile 中將您的後端函式作為 backend 關鍵字引數傳遞。例如,
import torch
def my_custom_backend(gm, example_inputs):
return gm.forward
def f(...):
...
f_opt = torch.compile(f, backend=my_custom_backend)
@torch.compile(backend=my_custom_backend)
def g(...):
...
更多示例請參見下文。
註冊自定義後端#
您可以使用 register_backend 裝飾器註冊您的後端,例如,
from torch._dynamo import register_backend
@register_backend
def my_compiler(gm, example_inputs):
...
除了 register_backend 裝飾器之外,如果您的後端位於另一個 Python 包中,您還可以透過 Python 包的入口點註冊您的後端,這為包提供了一種為另一個包註冊外掛的方式。
提示
您可以在 Python 打包文件 中瞭解更多關於 entry_points 的資訊。
要透過 entry_points 註冊您的後端,您可以在包的 setup.py 檔案中將您的後端函式新增到 torch_dynamo_backends 入口點組,如下所示:
...
setup(
...
'torch_dynamo_backends': [
'my_compiler = your_module.submodule:my_compiler',
]
...
)
請將 my_compiler(等號前)替換為您的後端的名稱,並將等號後的部分替換為您的後端函式的模組和函式名。入口點將在安裝包後新增到您的 Python 環境中。當您呼叫 torch.compile(model, backend="my_compiler") 時,PyTorch 會首先搜尋已透過 register_backend 註冊的名為 my_compiler 的後端。如果未找到,它將繼續搜尋透過 entry_points 註冊的所有後端。
註冊有兩個目的
您可以將包含後端函式名稱的字串傳遞給
torch.compile,而不是函式本身,例如,torch.compile(model, backend="my_compiler")。對於使用 minifier 是必需的。minifier 生成的任何程式碼都必須呼叫註冊您的後端函式的程式碼,通常是透過 `import` 語句。
AOTAutograd 後的自定義後端#
可以定義由 AOTAutograd 而非 TorchDynamo 呼叫的自定義後端。這主要有兩個原因:
使用者可以定義支援模型訓練的後端,因為 AOTAutograd 可以生成用於編譯的後向圖。
AOTAutograd 生成由 核心 Aten 運算元組成的 FX 圖。因此,自定義後端只需要支援核心 Aten 運算元集,這比整個 torch/Aten 運算元集要小得多。
將您的後端包裝在 torch._dynamo.backends.common.aot_autograd 中,並像以前一樣使用帶有 backend 關鍵字引數的 torch.compile。由 aot_autograd 包裝的後端函式應該具有與以前相同的約定。
後端函式透過 fw_compiler(前向編譯器)或 bw_compiler(後向編譯器)關鍵字引數傳遞給 aot_autograd。如果未指定 bw_compiler,則後向編譯函式預設為前向編譯函式。
一個注意事項是,AOTAutograd 要求後端返回的編譯函式是“裝箱的”。這可以透過使用 functorch.compile.make_boxed_func 來包裝編譯後的函式來完成。
例如,
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func
def my_compiler(gm, example_inputs):
return make_boxed_func(gm.forward)
my_backend = aot_autograd(fw_compiler=my_compiler) # bw_compiler=my_compiler
model_opt = torch.compile(model, backend=my_backend)
示例#
除錯後端#
如果您想更好地瞭解編譯過程中發生了什麼,可以建立一個自定義編譯器(在本節中稱為後端),它將列印 Dynamo 的位元組碼分析提取的 FX GraphModule 的漂亮列印版本,並返回一個 forward() 可呼叫物件。
例如
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def fn(x, y):
a = torch.cos(x)
b = torch.sin(y)
return a + b
fn(torch.randn(10), torch.randn(10))
執行上述示例將產生以下輸出:
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ------------------------------------------------------ ---------- --------
placeholder x x () {}
placeholder y y () {}
call_function cos <built-in method cos of type object at 0x7f1a894649a8> (x,) {}
call_function sin <built-in method sin of type object at 0x7f1a894649a8> (y,) {}
call_function add <built-in function add> (cos, sin) {}
output output output ((add,),) {}
這對於 torch.nn.Module 也有效,如下所示:
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(torch.cos(x))
mod = MockModule()
optimized_mod = torch.compile(mod, backend=my_compiler)
optimized_mod(torch.randn(10))
讓我們再看一個帶有控制流的示例:
from typing import List
import torch
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
print("my_compiler() called with FX graph:")
gm.graph.print_tabular()
return gm.forward # return a python callable
@torch.compile(backend=my_compiler)
def toy_example(a, b):
x = a / (torch.abs(a) + 1)
if b.sum() < 0:
b = b * -1
return x * b
for _ in range(100):
toy_example(torch.randn(10), torch.randn(10))
執行此示例將產生以下輸出:
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------- ------------------------------------------------------ ---------------- --------
placeholder a a () {}
placeholder b b () {}
call_function abs_1 <built-in method abs of type object at 0x7f8d259298a0> (a,) {}
call_function add <built-in function add> (abs_1, 1) {}
call_function truediv <built-in function truediv> (a, add) {}
call_method sum_1 sum (b,) {}
call_function lt <built-in function lt> (sum_1, 0) {}
output output output ((truediv, lt),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- ----------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (b, -1) {}
call_function mul_1 <built-in function mul> (x, mul) {}
output output output ((mul_1,),) {}
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ------ ----------------------- --------- --------
placeholder b b () {}
placeholder x x () {}
call_function mul <built-in function mul> (x, b) {}
output output output ((mul,),) {}
The order of the last two graphs is nondeterministic depending
on which one is encountered first by the just-in-time compiler.
加速後端#
整合提供卓越效能的自定義後端也非常簡單,我們將整合一個真實的後端,並使用 optimize_for_inference:
def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
scripted = torch.jit.script(gm)
return torch.jit.optimize_for_inference(scripted)
然後您就可以使用以下方式最佳化任何現有程式碼:
@torch.compile(backend=optimize_for_inference_compiler)
def code_to_accelerate():
...
可組合後端#
TorchDynamo 包含許多後端,可以使用 torch._dynamo.list_backends() 列出。您可以使用以下程式碼將這些後端組合在一起:
from torch._dynamo import lookup_backend
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
try:
trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
if trt_compiled is not None:
return trt_compiled
except Exception:
pass
# first backend failed, try something else...
try:
inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
if inductor_compiled is not None:
return inductor_compiled
except Exception:
pass
return gm.forward