Introduction to torch.compile
Created On: March 15, 2023 | Last Updated: October 15, 2025 | Last Verified: November 5, 2024
Author: William Wen
torch.compile is the new way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.
torch.compile accomplishes this by tracing through your Python code, looking for PyTorch operations. Code that is difficult to trace results in a **graph break**, which loses optimization opportunities rather than causing an error or silent incorrectness.
torch.compile is available in PyTorch 2.0 and later.
This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.
For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.
To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.
Required pip dependencies for this tutorial:

- torch >= 2.0
- numpy
- scipy

System requirements:

- A C++ compiler, such as g++
- Python development package (python-devel/python-dev)
Basic Usage
In this tutorial, we enable some logging to help us see what torch.compile is doing under the hood. The following code will print out the PyTorch operations that torch.compile traces.
import torch
torch._logging.set_logs(graph_code=True)
torch.compile is a decorator that takes an arbitrary Python function.
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))
@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b
print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_1_57703c6c_17e9_44be_adf9_87ae8a7f015f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 0.0663, 1.8726, 1.0057],
[-0.3487, 0.3188, 0.9310],
[ 1.8560, 0.4513, -0.4614]])
TRACED GRAPH
===== __compiled_fn_3_12712180_e493_4bc2_8b8e_dcdfd783faaa =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 0.2038, 0.5530, 0.2229],
[-0.3382, 0.5160, -0.0161],
[ 1.7310, 1.3559, 1.2261]])
torch.compile is applied recursively, so nested function calls within the top-level compiled function will also be compiled.
def inner(x):
    return torch.sin(x)

@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b
print(outer(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_5_03c189a8_83d7_41cc_a42b_e8e8d534d682 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 1.2845, -0.0892, -0.2115],
[ 1.3537, -0.0816, -0.0732],
[-0.3591, 1.5748, 0.7948]])
We can also optimize torch.nn.Module instances by calling their .compile() method or by torch.compile-ing the module directly. This is equivalent to torch.compile-ing the module's __call__ method (which indirectly calls forward).
t = torch.randn(10, 100)
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))
mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))
mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_7_d919aa2b_ce68_443d_ab75_c1f3ad8968a4 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_); l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear); linear = None
return (relu,)
tensor([[0.4863, 0.2575, 0.5411],
[0.1428, 0.0000, 0.3762],
[0.4444, 0.5583, 0.7902]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.0000, 0.0000, 1.4330],
[0.0000, 0.0000, 0.0536],
[0.0000, 0.0000, 0.1456]], grad_fn=<CompiledFunctionBackward>)
Demonstrating Speedups
Let's now demonstrate how torch.compile can speed up a simple PyTorch example. For a demonstration on a more complex model, see our end-to-end torch.compile tutorial.
def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u
opt_foo3 = torch.compile(foo3)
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000
inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
TRACED GRAPH
===== __compiled_fn_9_08a72ca3_c6ee_45c6_a198_0e8c99e7092d =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1; l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y); y = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2; z = None
return (u,)
compile: 0.40412646532058716
eager: 0.02964000031352043
Notice that torch.compile appears to take a lot longer to complete compared to eager. This is because torch.compile takes extra time on its first few runs to compile the model. torch.compile reuses compiled code whenever possible, so if we run our optimized model several more times, we should see a significant improvement compared to eager.
# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)
eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)
compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager time 0: 0.00088900001719594
eager time 1: 0.0008459999808110297
eager time 2: 0.0008459999808110297
eager time 3: 0.0008479999960400164
eager time 4: 0.000846999988425523
eager time 5: 0.0008420000085607171
eager time 6: 0.0008420000085607171
eager time 7: 0.0008509375038556755
eager time 8: 0.0008399999933317304
eager time 9: 0.0008440000237897038
~~~~~~~~~~
compile time 0: 0.0005019999807700515
compile time 1: 0.0003699999942909926
compile time 2: 0.00036100001307204366
compile time 3: 0.0003539999888744205
compile time 4: 0.00035700001171790063
compile time 5: 0.0003530000103637576
compile time 6: 0.0003530000103637576
compile time 7: 0.0003499999875202775
compile time 8: 0.0003539999888744205
compile time 9: 0.0003530000103637576
~~~~~~~~~~
(eval) eager median: 0.0008459999808110297, compile median: 0.0003539999888744205, speedup: 2.389830529376495x
~~~~~~~~~~
And indeed, we can see that running our model with torch.compile results in a significant speedup. Speedup mainly comes from reducing Python overhead and GPU read/writes, so the observed speedup may vary with factors such as model architecture and batch size. For example, if a model's architecture is simple and the amount of data is large, then the bottleneck will be GPU compute and the observed speedup may be less significant.
To see speedups on real models, check out our end-to-end torch.compile tutorial.
Advantages over TorchScript
Why should we use torch.compile over TorchScript? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.
TorchScript, by comparison, has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and raises errors when unsupported Python is used.
For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below) because only the actual control flow path is traced. In contrast, torch.compile handles it correctly.
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y
# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)
inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)
traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))
compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning:
Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
TorchScript scripting mode can handle data-dependent control flow, but it can require major code changes and raises errors when unsupported Python is used. As an illustration, see the sketch below.
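For instance (a hedged sketch, not part of the original run), scripting the earlier f1 would handle the data-dependent branch correctly, since torch.jit.script compiles the function's source rather than tracing a single execution path:

script_f1 = torch.jit.script(f1)
print("script 1, 1:", test_fns(f1, script_f1, (inp1, inp2)))  # expected: True
print("script 1, 2:", test_fns(f1, script_f1, (-inp1, inp2)))  # expected: True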
In the example below, we forget the TorchScript type annotations and receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor. In contrast, torch.compile works without requiring any type annotations.
import traceback as tb
torch._logging.set_logs(graph_code=True)
def f2(x, y):
    return x + y
inp1 = torch.randn(5, 5)
inp2 = 3
script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except:
    tb.print_exc()
compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
===== __compiled_fn_18_60f88fab_6a3d_4dcc_a2ea_16a1899bfb1f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
add: "f32[5, 5][5, 1]cpu" = l_x_ + 3; l_x_ = None
return (add,)
compile 2: True
~~~~~~~~~~
Graph Breaks
Graph breaks are one of the most fundamental concepts in torch.compile. They enable torch.compile to handle arbitrary Python code by interrupting compilation, running the unsupported code, then resuming compilation. The term "graph break" comes from the fact that torch.compile attempts to capture and optimize a graph of PyTorch operations; when unsupported Python code is encountered, that graph must be "broken". Graph breaks result in lost optimization opportunities, which may still be undesirable, but this is better than silent incorrectness or a hard crash.
Let's walk through an example with data-dependent control flow to get a better sense of how graph breaks work.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_20_d5309909_d209_4382_9b82_0ba74ced4ca8 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
return (lt, x)
TRACED GRAPH
===== __compiled_fn_24_24e667b5_a8e5_442d_b94a_a878f1114d23 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_x_ = L_x_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_26_d1830df0_39a5_4379_96f3_af6c112110cd =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
l_b_ = L_b_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
The first time we run bar, we see that torch.compile traced 2 graphs, corresponding to the following code (noting that b.sum() < 0 is False):

1. x = a / (torch.abs(a) + 1); b.sum()
2. return x * b
The second time we run bar, we take the other branch of the if statement and get 1 traced graph, corresponding to the code b = b * -1; return x * b. We do not see a graph output for x = a / (torch.abs(a) + 1) on the second run because torch.compile cached that graph from the first run and reused it.
Let's investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3 (the graph containing b = b * -1). On the other hand, if not b.sum() < 0, TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2.
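Conceptually, the compiled bar then behaves like the following sketch, where graph1, graph2, and graph3 are hypothetical stand-ins for the three compiled graphs (the actual resume-function machinery is internal to TorchDynamo):

# Hypothetical stand-ins for the three compiled graphs (illustration only).
def graph1(a, b):
    x = a / (torch.abs(a) + 1)
    return x, b.sum() < 0

def graph2(x, b):
    return x * b

def graph3(x, b):
    b = b * -1
    return x * b

def bar_conceptual(a, b):
    x, cond = graph1(a, b)  # graph 1 runs up to the data-dependent branch
    if cond:                # plain Python evaluates the conditional
        return graph3(x, b)
    return graph2(x, b)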
We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).
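The output below was presumably produced by re-running opt_bar with logging enabled, along these lines (a sketch; graph_code=True is passed again since set_logs resets unspecified options to their defaults):

torch._dynamo.reset()  # clear the torch.compile cache so bar is re-traced
torch._logging.set_logs(graph_code=True, graph_breaks=True)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)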
TRACED GRAPH
===== __compiled_fn_28_e75c1c8c_4795_4a16_8d6f_90d489a9e78e =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
return (lt, x)
TRACED GRAPH
===== __compiled_fn_32_e26b0760_f8cc_414d_a852_6092ac007ca7 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_x_ = L_x_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_34_2b406644_b833_40a0_96ec_c1f387d13c7f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
l_b_ = L_b_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
To maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon encountering the first graph break by using fullgraph=True.
# Reset to clear the torch.compile cache
torch._dynamo.reset()
opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
opt_bar_fullgraph(torch.randn(10), torch.randn(10))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
raise e.with_traceback(None) from e.__cause__ # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
if b.sum() < 0:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
In our example above, we can resolve this graph break by replacing the if statement with torch.cond.
from functorch.experimental.control_flow import cond
@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    x = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b
bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_37_6c5f108a_d951_495b_a538_024359c8fc5a =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = x = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: x = cond(b.sum() < 0, true_branch, false_branch, (b,))
sum_1: "f32[][]cpu" = l_b_.sum()
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
# File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:186 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,)); lt = cond_true_0 = cond_false_0 = None
x_1: "f32[10][1]cpu" = cond[0]; cond = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
mul: "f32[10][1]cpu" = x_1 * l_b_; x_1 = l_b_ = None
return (mul,)
class cond_true_0(torch.nn.Module):
def forward(self, l_b_: "f32[10][1]cpu"):
l_b__1 = l_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
mul: "f32[10][1]cpu" = l_b__1 * -1; l_b__1 = None
return (mul,)
class cond_false_0(torch.nn.Module):
def forward(self, l_b_: "f32[10][1]cpu"):
l_b__1 = l_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
clone: "f32[10][1]cpu" = l_b__1.clone(); l_b__1 = None
return (clone,)
tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])
In order to serialize graphs or to run graphs in different (i.e., Python-less) environments, consider using torch.export instead (available since PyTorch 2.1). One important restriction is that torch.export does not support graph breaks. Please see the torch.export tutorial for more details on torch.export.
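As a brief illustration (a minimal sketch, assuming the MyModule class defined earlier; not part of the original tutorial run), torch.export captures a single full graph ahead of time, and the resulting program can be saved and reloaded:

from torch.export import export, save, load

# Capture a single full graph ahead of time; a graph break here would raise an error
ep = export(MyModule(), (torch.randn(3, 3),))
save(ep, "my_module.pt2")
loaded = load("my_module.pt2")
print(loaded.module()(torch.randn(3, 3)))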
Check out the graph breaks section of the torch.compile programming model for tips on how to work around graph breaks.
Troubleshooting
Is torch.compile failing to speed up your model? Is compilation taking too long? Is your code recompiling excessively? Are you having a hard time dealing with graph breaks? Are you looking for tips on how to best use torch.compile? Or do you simply want to learn more about torch.compile's internals?
Check out the torch.compile programming model.
Conclusion
In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to TorchScript, and briefly describing graph breaks.
For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial.
To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.
We hope that you will give torch.compile a try!
Total running time of the script: (0 minutes 16.527 seconds)