torch.compile 故障排除#
創建於:2022 年 11 月 28 日 | 最後更新於:2025 年 8 月 14 日
您正嘗試在 PyTorch 模型上使用 torch.compile 來提高其效能,但效果不如預期。也許效能沒有提升,出現了崩潰,或者編譯時間過長。本文提供了技巧、解決方法和除錯工具,以幫助您克服這些挑戰。
內容
設定預期#
torch.compile 被設計為通用 PyTorch 編譯器。與之前的編譯器解決方案 TorchScript 不同,torch.compile 需要更少的程式碼修改,這意味著模型通常不需要從頭開始重寫。它還能更平穩地處理不支援的程式碼——不支援的程式碼會導致最佳化機會的丟失,而不是崩潰。
理想情況下,人們可以直接將 torch.compile 應用於任何 PyTorch 模型並享受自動加速。然而,在現實中,程式碼的複雜性可能導致以下三種情況之一:
torch.compile無縫工作,提供加速。需要一些程式碼修改。
torch.compile不會崩潰或花費太多時間,但您可能看不到顯著的效能提升。需要對程式碼進行大量更改。
我們預計大多數程式碼將屬於情況 (1) 和 (2)。本文件提供了按參與度級別排列的技巧,以幫助解決情況 (2) 中的程式碼問題。
術語#
以下術語與 torch.compile 問題故障排除相關。
圖中斷#
torch.compile 會跟蹤您的程式碼,並嘗試將您的 PyTorch 程式碼捕獲到一個 PyTorch 運算元的計算圖中(FX 圖)。然而,這並非總是可能的。當遇到無法跟蹤的程式碼時,就會發生“圖中斷”。圖中斷包括編譯到目前為止已確定的 FX 圖,執行不支援的程式碼,然後用新的 FX 圖從不支援的程式碼處恢復跟蹤。由於計算圖被打破,我們失去了最佳化機會,因此模型程式碼應儘可能避免圖中斷。圖中斷髮生在以下情況:
資料依賴的 if 語句
許多 Python 內建函式
C 函式
下面是一個由於 Python 內建庫中的 copy.deepcopy 函式引起的圖中斷示例(確切輸出可能有所不同)。
import torch
@torch.compile
def fn(x):
x = x + 1
with open("test.txt", "r") as f:
return x + len(f.read())
fn(torch.ones(3, 3))
$TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in fn
with open("test.txt", "r") as f:
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 635, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2414, in CALL
self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _call
self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 962, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 997, in call_function
return handler(tx, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/builtin.py", line 831, in <lambda>
return lambda *args: unimplemented(error_msg)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: builtin: open [<class 'torch._dynamo.variables.constant.ConstantVariable'>, <class 'torch._dynamo.variables.constant.ConstantVariable'>] False
Guard#
torch.compile 在跟蹤程式碼時會做出一些關於執行時值的假設。在跟蹤期間,我們會生成“guards”,這些是這些假設的執行時檢查。Guard 在編譯函式的未來呼叫中執行,以確定我們是否可以重用先前編譯的程式碼。執行時檢查的例子包括常量值、型別和物件 ID。
下面是一個生成的 Guard 的示例。TENSOR_MATCH Guard 檢查輸入的型別、裝置、dtype、形狀等。
import torch
@torch.compile
def fn(x):
return x + 1
fn(torch.ones(3, 3))
$ TORCH_LOGS="guards" python playground.py
GUARDS:
TREE_GUARD_MANAGER:
+- RootGuardManager
| +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:471 in init_ambient_guards
| +- GLOBAL_STATE: ___check_global_state()
| +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
| +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor(x)
| | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU), torch.float32, device=None, requires_grad=False, size=[3, 3], stride=[3, 1]) # return x + 1 # playground.py:6 in fn
| | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # return x + 1 # playground.py:6 in fn
重新編譯#
如果之前編譯的程式碼的所有例項的 Guard 都失敗了,那麼 torch.compile 必須“重新編譯”該函式,需要再次跟蹤原始程式碼。
在下面的示例中,需要重新編譯,因為檢查張量引數形狀的 Guard 失敗了。
import torch
@torch.compile
def fn(x):
return x + 1
fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
triggered by the following guard failure(s):
- 0/0: tensor 'L['x']' size mismatch at index 0. expected 3, actual 4
動態形狀#
torch.compile 最初假設張量形狀是靜態/不變的,並基於這些假設建立 Guard。透過使用“動態形狀”,我們可以讓 torch.compile 生成可以接受不同形狀張量輸入的編譯程式碼——我們避免了每次形狀不同時都重新編譯。預設情況下,自動動態形狀是啟用的 torch.compile(dynamic=None) ——如果由於形狀不匹配導致編譯失敗,則會嘗試使用動態形狀進行重新編譯。動態形狀也可以完全啟用 dynamic=True 或停用 dynamic=False。
下面,我們啟用了動態形狀,並注意到我們不再需要重新編譯。
import torch
@torch.compile(dynamic=True)
def fn(x):
return x + 1
fn(torch.ones(3, 3))
fn(torch.ones(4, 4))
$ TORCH_LOGS="dynamic,recompiles" python playground.py
create_symbol s0 = 3 for L['x'].size()[0] [2, int_oo] at playground.py:5 in fn (_dynamo/variables/builder.py:2718 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
produce_guards
produce_guards
有關動態形狀的更多資訊,請參閱 動態形狀手冊。
日誌工具#
tlparse / TORCH_TRACE#
tlparse / TORCH_TRACE 是一對工具,它們生成如下所示的編譯報告: https://web.mit.edu/~ezyang/Public/bhack-20240609-tlparse/index.html。
收集跟蹤非常容易。要收集跟蹤,請使用以下命令執行您的重現命令:
TORCH_TRACE="/tmp/tracedir" python foo.py
pip install tlparse
tlparse /tmp/tracedir
這種方法甚至適用於您正在執行分散式作業的情況,為每個 rank 提供一個跟蹤。它將在瀏覽器中開啟 HTML,類似於上面生成的 HTML。如果您正在為沒有獨立重現的複雜問題編寫 bug 報告,您仍然可以透過附加在 /tmp/tracedir 中生成的跟蹤日誌來極大地幫助 PyTorch 開發者。
警告
跟蹤日誌包含您的所有模型程式碼。如果您的模型是敏感的,請不要共享跟蹤日誌。跟蹤日誌不包含權重。
tlparse 的輸出主要面向 PyTorch 開發者,並且日誌格式易於上傳並在 GitHub 上共享。然而,作為非 PyTorch 開發者,您仍然可以從中提取有用的資訊。我們建議從報告中的內聯幫助文字開始,它解釋了其內容。以下是一些您可以從 tlparse 中獲得的見解:
透過檢視堆疊樹,瞭解了哪些模型程式碼被編譯了?如果您不熟悉正在編譯的程式碼庫,這一點尤其有用!
有多少圖中斷 / 不同的編譯區域?(每個獨立的編譯都是自己的顏色編碼塊,如 [0/0])。可能圖中斷的幀是淺綠色 [2/4]。如果有很多幀,那很可疑,這表明您發生了一些災難性的圖中斷,或者您的程式碼可能不適合
torch.compile。我重新編譯了某個幀多少次?反覆重新編譯的幀會顯示為: [10/0] [10/1] [10/2] - 如果某項被反覆重新編譯,那非常可疑,值得深入研究,即使它不是您問題的根本原因。
是否存在編譯錯誤?出現錯誤的幀將顯示為 [0/1]。
為給定幀生成了哪些中間編譯器產品?例如,您可以檢視高階生成的 FX 圖或生成的 Triton 程式碼。
某個幀是否有相關資訊?您可以在
compilation_metrics中找到它們。
TORCH_LOGS#
您可以使用 TORCH_LOGS 環境變數選擇性地啟用 torch.compile 堆疊的部分以進行日誌記錄。TORCH_LOGS 實際上是 tlparse 的日誌來源。TORCH_LOGS 環境變數的格式如下:
TORCH_LOGS="<option1>,<option2>,..." python foo.py
有用的高階選項包括:
graph_breaks:記錄使用者程式碼中圖中斷的位置及其原因。guards:記錄生成的 Guard。recompiles:記錄哪個函式被重新編譯以及導致重新編譯的失敗 Guard。dynamic:記錄與動態形狀相關的日誌。
此外,您還可以使用 torch._logging.set_logs 以程式設計方式設定日誌選項。
import logging
torch._logging.set_logs(graph_breaks=True)
...
TORCH_LOGS 的更多選項請參閱 TORCH_LOGS 選項摘要。有關完整選項列表,請參閱 torch._logging 和 torch._logging.set_logs。
tlparse 與 TORCH_LOGS#
通常,我們建議在遇到問題時首先使用 tlparse。tlparse 非常適合除錯大型模型並獲得模型編譯方式的高階概覽。另一方面,TORCH_LOGS 更適合小型示例和精細的除錯細節,當您已經大致瞭解是哪個 torch.compile 元件引起問題時。
簡單的解決方法#
在這裡,我們描述了一些針對 torch.compile 問題的解決方法,這些問題涉及小的程式碼修改或更改一些 torch.compile 設定。
在哪裡應用 torch.compile?#
我們建議將 torch.compile 應用於最高級別的、不會導致過多問題的函式。通常,這是您的訓練或評估步驟(帶最佳化器但沒有迴圈),您的頂層 nn.Module,或一些子 nn.Module。 torch.compile 特別不擅長處理分散式包裝模組(如 DDP 或 FSDP),因此請考慮將 torch.compile 應用於傳遞給包裝器的內部模組。
# inference
model = ...
opt_model = torch.compile(model)
for _ in range(N_ITERS):
inp = ...
out = opt_model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())
@torch.compile
def train(mod, data):
opt.zero_grad(True)
pred = mod(data[0])
loss = torch.nn.CrossEntropyLoss()(pred, data[1])
loss.backward()
opt.step()
for _ in range(N_ITERS):
inp = ...
train(model, inp)
# DistributedDataParallel
model = ...
opt_model = torch.compile(model)
model_ddp = DistributedDataParallel(opt_model, ...)
for _ in range(N_ITERS):
inp = ...
out = model_ddp(inp)
停用和抑制錯誤#
對於某些模型架構,模型的一部分可能特別難以編譯——要麼有很多圖中斷,要麼發生崩潰。您可能希望顯式停用這些有問題的模型部分,以便您可以將 torch.compile 應用於可工作的部件。您可以使用 `@torch.compiler.disable` 裝飾器來實現此目的。當 torch.compile 嘗試呼叫已停用的函式時,它會中斷圖並跳過停用函式的跟蹤,然後在不支援的函式呼叫後用新的 FX 圖恢復跟蹤。預設情況下,從被停用函式發出的所有遞迴呼叫也都被停用。使用 `recursive=False` 選項可以允許遞迴呼叫的編譯。
def bad1_inner(...):
# skipped
@torch.compiler.disable
def bad1_outer(...):
# skipped
bad1_inner(...)
def bad2_inner(...)
# traced
@torch.compiler.disable(recursive=False)
def bad2_outer(...):
# skipped
bad2_inner(...)
@torch.compile
def fn(...):
# graph break
bad1_outer(...)
...
# graph break
bad2_outer(...)
例如,我們使用 torch.compiler.disable 來停用推薦模型中稀疏架構上的 torch.compile,因為稀疏架構很難編譯。預處理和日誌記錄函式是導致大量圖中斷且從編譯中獲益不大的函式的其他示例。
如果您遇到編譯器崩潰並希望繼續進行,可以設定 torch._dynamo.config.suppress_errors = True。當編譯器崩潰時,我們將跳過函式的跟蹤並稍後重試。這不是最佳實踐——最好最終手動新增所需的停用註解。
解決圖中斷#
為了最大化最佳化機會,減少圖中斷的數量很重要。請記住,您可以使用 tlparse 或 TORCH_LOGS="graph_breaks" 來檢視正在發生的圖中斷。通常,圖中斷是由以下原因之一引起的:
您正在嘗試執行一項根本無法跟蹤的操作,例如資料依賴的控制流。
您正在嘗試執行一項尚未支援的操作。例如,我們目前對跟蹤使用內建 Python
inspect模組的程式碼的支援有限。您的程式碼中有錯誤。例如,您可能嘗試使用錯誤的引數數量呼叫函式。
圖中斷日誌會告訴您使用者程式碼的位置和圖中斷的原因。不幸的是,許多圖中斷需要對 Dynamo 有更深入的瞭解才能處理。甚至可能難以確定您的圖中斷的真正原因屬於這三個原因中的哪一個。我們正在努力使圖中斷訊息更具可操作性。
此外,圖中斷對丟失最佳化機會的影響也不同。例如,發生在模型 forward 中間的圖中斷可能比發生在 forward 開始處的預處理部分的圖中斷產生更負面的影響。因此,阻止每一個中斷並非至關重要,而是要阻止那些導致效能顯著下降的中斷。
如果圖中斷訊息沒有建議任何操作,您懷疑圖中斷的原因是 (2),並且您認為圖中斷導致效能下降,那麼請將該圖中斷報告為一個問題。如果一個函式有很多圖中斷,請考慮停用該函式的編譯,因為圖中斷的開銷可能會變得過高。
下面是一些常見的圖中斷及其解決方法。
資料依賴的操作#
torch.compile 在資料依賴的操作上會發生圖中斷,例如資料依賴的控制流(if 語句、帶有張量的迴圈)和直接的張量資料訪問(.item、.data_ptr)。
import torch
@torch.compile
def fn(x):
y = x.sum()
if y > 0:
return x + y.item()
return x - y.item()
fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:6
Reason: Data-dependent jump
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 6, in fn
if y > 0:
Graph break in user code at /data/users/williamwen/pytorch/playground.py:7
Reason: Unsupported: Tensor.item
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 7, in torch_dynamo_resume_in_fn_at_6
return x + y.item()
Traceback (most recent call last):
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 616, in wrapper
return inner_fn(self, inst)
^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2288, in CALL
self._call(inst)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 2282, in _call
self.call_function(fn, args, kwargs)
File "/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py", line 838, in call_function
self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py", line 1038, in call_function
return self.obj.call_method(tx, self.name, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
result = handler_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/users/williamwen/pytorch/torch/_dynamo/variables/tensor.py", line 773, in method_item
unimplemented("Tensor.item")
File "/data/users/williamwen/pytorch/torch/_dynamo/exc.py", line 304, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: Tensor.item
這些圖中斷的通用解決方法是避免執行資料依賴的操作。一些具體的解決方法是:
如果您的控制流實際上不依賴於資料值,請考慮修改您的程式碼以對常量執行控制流。
# old
x = torch.randn(3, 3)
@torch.compile
def fn(y):
if x.sum() > 0:
return y + x
else:
return y - x
# new
x = torch.randn(3, 3)
cond = (x.sum() > 0).item()
@torch.compile
def fn(y):
if cond:
return y + x
else:
return y - x
在資料依賴的控制流中使用像
torch.cond(https://pytorch.com.tw/docs/stable/cond.html) 這樣的高階運算元。
# old
@torch.compile
def fn(x):
if x.sum() > 0:
return x + 1
return x - 1
# new
@torch.compile
def fn(x):
return torch.cond(
x.sum() > 0,
lambda x: x + 1,
lambda x: x - 1,
(x,),
)
如果您有
.item()呼叫,請嘗試torch._dynamo.config.capture_scalar_outputs = True或TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1。將函式中有問題的部分包裝在自定義運算元中。
自定義運算元#
如果您有 torch.compile 難以跟蹤的程式碼,無論是由於缺少支援還是根本不相容,您都可以考慮將有問題的程式碼包裝在自定義運算元中。
自定義運算元需要額外的一些工作才能使其與 torch.compile 相容。有關更多詳細資訊,請參閱 https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html。
列印#
列印/日誌/發出警告將導致圖中斷。如果您有一個函式進行許多日誌記錄呼叫,例如,一個記錄有關訓練迭代資料的函式,請考慮在該函式上應用 torch.compiler.disable。
或者,您可以嘗試使用 torch._dynamo.config.reorderable_logging_functions。此配置用於重新排序日誌函式,以便它們在跟蹤函式的末尾被呼叫,從而避免圖中斷。但是,如果發生突變,日誌內容可能會有所不同。
import torch
torch._dynamo.config.reorderable_logging_functions.add(print)
@torch.compile
def fn(x):
x += 1
print("log!")
return torch.sin(x)
fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
log!
程式碼錯誤#
您的程式碼可能不正確,或者遇到了來自 torch.compile 之外的錯誤。在下面的程式碼中,我們在 torch.sin 呼叫中犯了一個拼寫錯誤,提供了一個額外的引數。
import torch
@torch.compile
def fn(x):
y = torch.sin(x, x)
return y
fn(torch.ones(3, 3))
$ TORCH_LOGS="graph_breaks" python playground.py
Graph break in user code at /data/users/williamwen/pytorch/playground.py:5
Reason: Unsupported: TypeError <built-in method sin of type object at 0x7fd6fd764600>: sin() takes 1 positional argument but 2 were given
User code traceback:
File "/data/users/williamwen/pytorch/playground.py", line 5, in fn
y = torch.sin(x, x)
...
從日誌中很難判斷錯誤是由您的程式碼引起的,還是由 torch.compile 的 bug 引起的。為了區分,我們建議嘗試在沒有 torch.compile 的情況下執行您的程式碼,看看是否仍然出現錯誤。
處理重新編譯#
您可以使用 tlparse 或 TORCH_LOGS=recompiles 檢視重新編譯及其原因。
是否啟用了動態形狀?#
由於形狀不匹配而導致的重新編譯形式為:
tensor 'L['x']' size mismatch at index 0. expected 3, actual 4
確保 torch.compile 的 `dynamic` 選項未設定為 False。預設選項 dynamic=None 只會在第一次編譯後嘗試動態形狀。您可以設定 dynamic=True 以提前儘可能地進行動態編譯。
有關動態形狀的更多資訊,請參閱 動態形狀手冊。
更改快取大小限制#
函式可以被重新編譯的次數是有限制的,由 torch._dynamo.config.recompile_limit 和 torch._dynamo.config.accumulated_recompile_limit 確定。如果任一限制被超過,我們將不再嘗試重新編譯該函式,而是會以惰性模式執行該函式。 torch.compile 還會發出警告,其中包含受影響的函式以及達到了哪個限制。在下面的示例中,每次函式呼叫都會導致一次重新編譯嘗試。當我們達到快取大小限制(8)時,我們會停止嘗試重新編譯。
import torch
@torch.compile(dynamic=False)
def fn(x):
return x + 1
for i in range(1, 10):
fn(torch.ones(i))
$ python playground.py
torch._dynamo hit config.recompile_limit (8)
function: 'fn' (/data/users/williamwen/pytorch/playground.py:5)
last reason: 0/0: tensor 'L['x']' size mismatch at index 0. expected 1, actual 9
如果您知道重新編譯次數有一個合理的固定上限,您可以提高快取大小限制。如果重新編譯的成本超過了編譯的好處,那麼您可以考慮降低快取大小限制。
用張量包裝常量#
預設情況下,int / float 變數被視為常量並以此進行 Guard。在下面的示例中,每次函式呼叫都會導致一次重新編譯。
import torch
@torch.compile
def fn(x, c):
return x + c
for i in range(1, 10):
fn(torch.ones(i), 0.5 + i)
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function fn in /data/users/williamwen/pytorch/playground.py:3
triggered by the following guard failure(s):
- 0/7: L['c'] == 8.5
- 0/6: L['c'] == 7.5
- 0/5: L['c'] == 6.5
- 0/4: L['c'] == 5.5
- 0/3: L['c'] == 4.5
- 0/2: L['c'] == 3.5
- 0/1: L['c'] == 2.5
- 0/0: L['c'] == 1.5
torch._dynamo hit config.recompile_limit (8)
function: 'fn' (/data/users/williamwen/pytorch/playground.py:3)
last reason: 0/0: L['c'] == 1.5
特別是對於 LR 排程器,使用常量進行初始化可能導致重新編譯。
import torch
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)
@torch.compile
def fn(inp):
opt.zero_grad(True)
out = mod(inp).sum()
out.backward()
opt.step()
sched.step()
for i in range(1, 10):
fn(torch.ones(3, 3))
$ TORCH_LOGS="recompiles" python playground.py
Recompiling function step in /data/users/williamwen/pytorch/torch/optim/adam.py:189
triggered by the following guard failure(s):
- 3/7: L['self'].param_groups[0]['lr'] == 0.004782969000000002
- 3/6: L['self'].param_groups[0]['lr'] == 0.005314410000000002
- 3/5: L['self'].param_groups[0]['lr'] == 0.005904900000000002
- 3/4: L['self'].param_groups[0]['lr'] == 0.006561000000000002
- 3/3: L['self'].param_groups[0]['lr'] == 0.007290000000000001
- 3/2: L['self'].param_groups[0]['lr'] == 0.008100000000000001
- 3/1: L['self'].param_groups[0]['lr'] == 0.009000000000000001
- 3/0: L['self'].param_groups[0]['lr'] == 0.01
torch._dynamo hit config.recompile_limit (8)
function: 'step' (/data/users/williamwen/pytorch/torch/optim/adam.py:189)
last reason: 3/0: L['self'].param_groups[0]['lr'] == 0.01
在以上兩個示例中,我們可以將浮點變數包裝在張量中以防止重新編譯。
# first example
for i in range(1, 10):
fn(torch.ones(i), torch.tensor(0.5 + i))
# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
報告問題#
如果上述解決方法不足以使 torch.compile 工作,那麼您應該考慮將問題報告給 PyTorch。但是,您可以做一些事情來使我們的生活變得更輕鬆。
消融#
使用 torch.compile 的 `backend=` 選項檢查 torch.compile 堆疊的哪個元件導致了問題。特別是,嘗試:
torch.compile(fn, backend="eager"),它只執行torch.compile的圖捕獲元件 TorchDynamo。torch.compile(fn, backend="aot_eager"),它執行 TorchDynamo 和 AOTAutograd,後者在編譯期間還會生成後向圖。torch.compile(fn, backend="aot_eager_decomp_partition"),它執行 TorchDynamo 和 AOTAutograd,並進行運算元分解/分割槽。torch.compile(fn, backend="inductor"),它執行 TorchDynamo、AOTAutograd 和 TorchInductor,這是生成編譯核心的後端 ML 編譯器。
如果您僅在 Inductor 後端出現問題,您還可以測試各種 Inductor 模式:
torch.compile(fn, backend="inductor", mode="default")torch.compile(fn, backend="inductor", mode="reduce-overhead")torch.compile(fn, backend="inductor", mode="max-autotune")
您還可以檢查動態形狀是否導致任何後端出現問題:
torch.compile(fn, dynamic=True)(始終使用動態形狀)torch.compile(fn, dynamic=False)(從不使用動態形狀)torch.compile(fn, dynamic=None)(自動動態形狀)
二分查詢#
您是否嘗試過最新的 nightly 版本?過去有效的東西現在是否不再有效?您能否二分查詢以確定問題的首次出現版本?二分查詢對於效能、準確性或編譯時間迴歸尤其有用,因為這些問題很難立即確定其來源。
建立重現器#
建立重現器需要大量工作,如果您沒有時間完成它,那也是完全可以接受的。然而,如果您是一位熟悉 torch.compile 內部結構但不熟悉的有積極性的使用者,建立獨立的重現器對我們修復 bug 的能力有巨大的影響。沒有重現器,您的 bug 報告必須包含足夠的資訊,以便我們確定問題的根源並從頭開始編寫重現器。
以下是可用重現器的列表,按偏好程度從高到低排序:
獨立的、小的重現器: 一個沒有外部依賴項的指令碼,程式碼行數少於 100 行,執行後能重現問題。
獨立的、大的重現器: 即使程式碼很大,獨立性也是一個巨大的優勢!
具有可管理依賴項的非獨立重現器: 例如,如果您可以透過執行一個指令碼並在 `pip install transformers` 後進行重現,那是可以管理的。我們很可能能夠執行它並進行調查。
需要大量設定的非獨立重現器: 這可能涉及下載資料集、多個環境設定步驟或需要 Docker 映象的特定系統庫版本。設定越複雜,我們重現環境就越困難。
注意
Docker simplifies setup but complicates changes to the environment, so it's not a perfect solution, though we'll use it if necessary.
在某種程度上,一個可以在單個程序中執行的重現器比需要多程序訓練的重現器更好(但同樣,如果您只有一個多程序重現器,我們會接受!)。
此外,以下是可以在您的問題中檢查的方面列表,您可以在重現器中嘗試重現這些方面:
Autograd。是否有 `requires_grad=True` 的張量輸入?是否在輸出上呼叫了 `backward()`?
動態形狀。您是否設定了 `dynamic=True`?或者您是否多次使用不同形狀運行了測試程式碼?
自定義運算元。在實際工作流程中是否涉及自定義運算元?您能否使用 Python 自定義運算元 API 重現其某些重要特徵?
配置。您是否設定了所有相同的配置?這包括 `torch._dynamo.config` 和 `torch._inductor.config` 設定,以及 `torch.compile` 的引數,如 `backend` / `mode`。
上下文管理器。您是否重現了任何活動的上下文管理器?這可能包括 `torch.no_grad`、自動混合精度、`TorchFunctionMode` / `TorchDispatchMode`、啟用檢查點、編譯的 autograd 等。
張量子類。是否涉及張量子類?
Minifier#
Minifier 是一個早期的 torch.compile 工具,它接受一個在嘗試執行或編譯時崩潰的 FX 圖,找到一個同樣崩潰的子圖,並輸出執行該子圖操作的程式碼。本質上,minifier 找到了某種 torch.compile 相關崩潰的最小重現。這假設我們能夠成功跟蹤程式碼。
不幸的是,如今大多數時候,minifier 的效果不佳,可能需要其他方法。這可能是因為以這種方式自動重現的 bug 通常更容易修復並且已經得到解決,剩下更復雜且不易重現的問題。然而,嘗試使用 minifier 是很直接的,所以即使它可能不成功,也值得一試。
操作 minifier 的說明可以在此處找到。如果編譯器崩潰,您可以設定 TORCHDYNAMO_REPRO_AFTER="dynamo" 或 TORCHDYNAMO_REPRO_AFTER="aot"。aot 選項更可能成功,儘管它可能無法識別 AOTAutograd 問題。這將生成 `repro.py` 檔案,可能有助於診斷問題。對於準確性相關的問題,請考慮設定 TORCHDYNAMO_REPRO_LEVEL=4。請注意,這可能並不總是能成功識別有問題的子圖。
深入除錯#
本節提供了獨立除錯 torch.compile 問題或深入瞭解 torch.compile 堆疊的工具和技術。這些方法比上面介紹的方法更復雜,並且是 PyTorch 開發者經常用來除錯實際 torch.compile 問題的。
下面是堆疊的高階概述:

堆疊由三個主要元件組成:TorchDynamo、AOTAutograd 和 Inductor。我們的除錯策略包括首先識別錯誤發生的元件,然後單獨除錯該元件。要確定負責該問題的元件,請參閱上面“報告問題”下的“消融”部分。有關除錯特定元件的指導,請參閱以下各節。
TorchDynamo#
記錄 Dynamo 正在跟蹤的內容#
`TORCH_LOGS=trace_bytecode` 選項使您能夠檢視 Dynamo 正在跟蹤的確切位元組碼指令,以及 Python 直譯器堆疊的符號表示。在遇到圖中斷或崩潰時,建議檢查最後幾個被跟蹤的位元組碼指令。
您還可以使用 `TORCH_LOGS=trace_source` 來檢視 Dynamo 正在跟蹤的原始碼行。這可以與 `trace_bytecode` 結合使用,以檢視每個被跟蹤的位元組碼指令對應的原始碼行。
最後,您可以使用 `TORCH_LOGS=graph_code` 來查看錶示 Dynamo 跟蹤的 FX 圖的 Python 程式碼。您可以檢視此程式碼以仔細檢查正在跟蹤的正確運算元。
import torch
def g(x, y):
return x + y
@torch.compile(backend="eager")
def f(x):
x = torch.sin(x)
x = g(x, x)
return x
f(torch.ones(3, 3))
$ TORCH_LOGS="trace_bytecode,trace_source,graph_code" python playground.py
TRACE starts_line /data/users/williamwen/pytorch/playground.py:6 in f ()
@torch.compile(backend="eager")
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:8 in f (f)
x = torch.sin(x)
TRACE LOAD_GLOBAL torch []
TRACE LOAD_ATTR sin [NullVariable(), PythonModuleVariable(<module 'torch' from '/data/users/williamwen/pytorch/torch/__init__.py'>)]
TRACE LOAD_FAST x [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>)]
TRACE CALL 1 [NullVariable(), TorchInGraphFunctionVariable(<built-in method sin of type object at 0x7f00f6964600>), LazyVariableTracker()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:9 in f (f)
x = g(x, x)
TRACE LOAD_GLOBAL g []
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable()]
TRACE LOAD_FAST x [NullVariable(), UserFunctionVariable(), TensorVariable()]
TRACE CALL 2 [NullVariable(), UserFunctionVariable(), TensorVariable(), TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:3 in g (g) (inline depth: 1)
def g(x, y):
TRACE RESUME 0 []
TRACE starts_line /data/users/williamwen/pytorch/playground.py:4 in g (g) (inline depth: 1)
return x + y
TRACE LOAD_FAST x []
TRACE LOAD_FAST y [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE RETURN_VALUE None [TensorVariable()]
TRACE STORE_FAST x [TensorVariable()]
TRACE starts_line /data/users/williamwen/pytorch/playground.py:10 in f (f)
return x
TRACE LOAD_FAST x []
TRACE RETURN_VALUE None [TensorVariable()]
TRACED GRAPH
===== __compiled_fn_1 =====
/data/users/williamwen/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
# File: /data/users/williamwen/pytorch/playground.py:8 in f, code: x = torch.sin(x)
x: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /data/users/williamwen/pytorch/playground.py:4 in g, code: return x + y
x_1: "f32[3, 3][3, 1]cpu" = x + x; x = None
return (x_1,)
設定 Dynamo 跟蹤的斷點#
有時在 Dynamo/使用者程式碼中插入斷點有助於檢視 Dynamo 在跟蹤使用者程式碼時的狀態。不幸的是,以常規 Python 方式插入斷點會導致 TorchDynamo 中的圖中斷,因此我們無法在計劃設定斷點的地方檢視 Dynamo 的狀態。
設定斷點的第一種方法是在 Dynamo 原始碼中插入。三個推薦的斷點位置是:
在
torch/_dynamo/symbolic_convert.py中,在命名與有問題的位元組碼指令相同的函式處設定斷點,例如 `def CALL_FUNCTION` 和 `def STORE_ATTR`。您可以根據輸入有條件地設定斷點,例如,指令的 `argval`,或者位於堆疊頂部的物件的名稱,因為某些位元組碼操作碼經常被使用。在圖中斷或錯誤起源處設定斷點。通常,圖中斷是從對 `unimplemented(...)` 的呼叫發出的。
在
torch/_dynamo/variables/builder.py, function:_wrap中設定斷點。您很可能需要根據輸入有條件地設定斷點。此函式決定如何符號化地表示給定值。如果您懷疑某個值被錯誤地表示,請考慮在此處設定斷點。
插入斷點的第二種方法是使用 torch._dynamo.comptime.comptime.breakpoint。
from torch._dynamo.comptime import comptime
@torch.compile
def f(...):
...
comptime.breakpoint()
...
comptime 斷點很方便,因為它允許您在正在跟蹤的使用者程式碼的特定位置檢查 Dynamo 狀態。它不需要您在 Dynamo 原始碼中插入斷點或根據變數有條件地設定斷點。
當觸發 comptime 斷點時,您可以執行以下操作:
ctx.print_bt()列印使用者堆疊跟蹤。ctx.print_locals()列印所有當前區域性變數。ctx.print_graph()列印當前跟蹤的圖。ctx.disas()列印當前跟蹤函式的位元組碼。使用標準的 `pdb` 命令,例如 `bt/u/d/n/s/r` - 您可以向上遍歷 `pdb` 堆疊以檢查更多 Dynamo 內部。
import torch
from torch._dynamo.comptime import comptime
@torch.compile(backend="eager")
def f(x):
y = x + 1
comptime.breakpoint()
y = y + 1
return y
f(torch.ones(3, 3))
$ python playground.py
--Return--
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_bt()
File "/data/users/williamwen/pytorch/playground.py", line 7, in f
comptime.breakpoint()
(Pdb) ctx.print_locals()
x = FakeTensor(..., size=(3, 3))
y = FakeTensor(..., size=(3, 3))
(Pdb) bt
...
/data/users/williamwen/pytorch/torch/_dynamo/symbolic_convert.py(826)call_function()
-> self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
/data/users/williamwen/pytorch/torch/_dynamo/variables/misc.py(331)call_function()
-> func(ComptimeContext(tx))
> /data/users/williamwen/pytorch/torch/_dynamo/comptime.py(392)inner()->None
-> builtins.breakpoint()
(Pdb) ctx.print_graph()
def forward(self, L_x_: "f32[3, 3]"):
l_x_ = L_x_
# File: /data/users/williamwen/pytorch/playground.py:6 in f, code: y = x + 1
y: "f32[3, 3]" = l_x_ + 1; l_x_ = y = None
AOTAutograd#
AOTAutograd 錯誤通常難以除錯 - 我們建議直接提交一個問題。AOTAutograd 的日誌輸出主要有助於檢視 Inductor 的輸入是什麼。
TORCH_LOGS 選項摘要#
有用的 TORCH_LOGS 選項摘要如下:
選項 |
描述 |
|---|---|
+all |
輸出所有 |
+dynamo |
輸出 TorchDynamo 的除錯日誌。 |
+aot |
輸出 AOTAutograd 的除錯日誌。 |
+inductor |
輸出 TorchInductor 的除錯日誌。 |
dynamic |
輸出動態形狀相關的日誌。 |
graph_code |
輸出 Dynamo 生成的 FX 圖的 Python 程式碼。 |
graph_sizes |
輸出 Dynamo 生成的 FX 圖的張量大小。 |
trace_bytecode |
輸出 Dynamo 正在跟蹤的位元組碼指令以及 Dynamo 正在維護的符號直譯器堆疊。 |
trace_source |
輸出 Dynamo 當前正在跟蹤的原始原始碼行。 |
bytecode |
輸出 Dynamo 生成的位元組碼。 |
guards |
輸出生成的 Guard。 |
recompiles |
輸出重新編譯原因(僅第一個失敗的 Guard 檢查)。 |
recompiles_verbose |
輸出重新編譯時所有失敗的 Guard 檢查。 |
aot_graphs |
輸出 AOTAutograd 生成的圖。 |
aot_joint_graphs |
輸出 AOTAutograd 生成的前向-後向聯合圖。 |
output_code |
輸出 Inductor 生成的程式碼。 |
kernel_code |
輸出 Inductor 按核心生成的程式碼。 |
schedule |
輸出 Inductor 排程日誌。 |
perf_hints |
輸出 Inductor 效能提示日誌。 |
fusion |
輸出 Inductor fusion 日誌。 |
有關完整選項列表,請參閱 torch._logging 和 torch._logging.set_logs。