Dealing with Recompilations
Created On: Jul 29, 2025 | Last Updated On: Jul 29, 2025
Recompilations are necessary for torch.compile to guarantee correctness, but they can result in significantly increased compile time. Thus, minimizing recompilations while preserving correctness is essential for keeping compile time down.
You can view recompilations and their reasons using tlparse or TORCH_LOGS=recompiles.
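As an alternative to the TORCH_LOGS environment variable, recompile logging can also be enabled from Python via torch._logging.set_logs. A minimal sketch; the exact log formatting may differ across PyTorch versions.

import torch

# Programmatic equivalent of TORCH_LOGS="recompiles": log every
# recompilation together with the guard failures that triggered it.
torch._logging.set_logs(recompiles=True)

@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))  # size change: the recompile and its reason are logged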
Is Dynamic Shapes Enabled?
In the example below, we recompile due to mismatched shapes.
@torch.compile
def fn(x):
    return x + 1

fn(torch.ones(3))
fn(torch.ones(4))
Recompiling function fn in /tmp/ipykernel_990/2479206322.py:1
triggered by the following guard failure(s):
- 0/0: tensor 'x' size mismatch at index 0. expected 3, actual 4
tensor([2., 2., 2., 2.])
Make sure that the dynamic option of torch.compile is not set to False. The default option, dynamic=None, only attempts dynamic shapes after the first compilation. You can set dynamic=True to compile as dynamically as possible up front.
@torch.compile(dynamic=True)
def gn(x):
    return x + 1

gn(torch.ones(3))
gn(torch.ones(4))
tensor([2., 2., 2., 2.])
For more information on dynamic shapes, including how to deal with errors and recompilations caused by dynamic shapes, see the dynamic shapes manual.
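If you already know which input dimension will vary, you can also mark it as dynamic before the first call, so the initial compilation does not specialize on that size. Below is a minimal sketch using torch._dynamo.mark_dynamic; the exact specialization behavior may vary across PyTorch versions.

import torch
import torch._dynamo

@torch.compile
def fn(x):
    return x + 1

x = torch.ones(3)
# Mark dimension 0 as dynamic up front so that changing its size
# later does not trigger a shape-mismatch recompilation.
torch._dynamo.mark_dynamic(x, 0)
fn(x)
fn(torch.ones(4))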
Wrapping Constants with Tensors
By default, int / float variables are treated as constants and are guarded on their exact value. In the example below, each function call results in a recompilation.
@torch.compile
def fn(x, c):
    return x + c

for i in range(5):
    fn(torch.ones(i), 0.5 + i)
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 2/0: c == 0.5 # return x + c # mp/ipykernel_990/3647755280.py:3 in fn
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 2/1: tensor 'x' size mismatch at index 0. expected 1, actual 2
- 2/0: c == 0.5 # return x + c # mp/ipykernel_990/3647755280.py:3 in fn
In particular, for learning rate schedulers, initializing with a constant can cause recompilations.
mod = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(mod.parameters(), lr=0.01)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, 0.9)

@torch.compile
def gn(inp):
    opt.zero_grad(True)
    out = mod(inp).sum()
    out.backward()
    opt.step()
    sched.step()

for i in range(5):
    gn(torch.ones(3, 3))
Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:213
triggered by the following guard failure(s):
- 7/0: self.param_groups[0]['lr'] == 0.01 # for group in self.param_groups: # optim/adam.py:228 in step
In both examples above, we can wrap the float variables in tensors in order to prevent recompilations.
# first example
for i in range(5):
    fn(torch.ones(i), torch.tensor(0.5 + i))

# second example
opt = torch.optim.Adam(mod.parameters(), lr=torch.tensor(0.01))
sched = torch.optim.lr_scheduler.ExponentialLR(opt, torch.tensor(0.9))
for i in range(5):
    gn(torch.ones(3, 3))
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 1
Recompiling function fn in /tmp/ipykernel_990/3647755280.py:1
triggered by the following guard failure(s):
- 0/1: tensor 'x' size mismatch at index 0. expected 1, actual 2
- 0/0: tensor 'x' size mismatch at index 0. expected 0, actual 2
Changing the Cache Size Limit
There is a limit to how many times a function can be recompiled, determined by torch._dynamo.config.cache_size_limit and torch._dynamo.config.accumulated_cache_size_limit (the exact difference between these two values is detailed in torch/_dynamo/cache_size.py). If the Dynamo cache limit is hit, then all future compilation attempts **will result in the function being skipped (run eagerly)**. Dynamo will still attempt to use previously compiled bytecode for subsequent function calls if the guards pass. Note that when the recompilation limit is hit, **all nested function calls will be skipped** (Dynamo will try to use previously compiled bytecode for the nested functions). Dynamo will also issue a warning containing the affected function and which limit was hit.
In the example below, each function call results in a recompilation attempt. Once we hit the cache size limit (8 by default), we stop attempting to recompile. (Note that we set dynamic=False for demonstration purposes in order to force a recompilation every time.)
@torch.compile(dynamic=False)
def fn(x):
    return x + 1

for i in range(1, 10):
    # recompile every time due to dynamic=False
    fn(torch.ones(i))
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 3
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 4
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 4
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 5
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 5
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 5
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 6
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 6
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 6
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 6
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 7
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 7
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 7
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 7
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 7
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 8
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 8
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 8
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 8
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 8
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 8
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function fn in /tmp/ipykernel_990/3054308037.py:1
triggered by the following guard failure(s):
- 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
- 8/6: tensor 'x' size mismatch at index 0. expected 7, actual 9
- 8/5: tensor 'x' size mismatch at index 0. expected 6, actual 9
- 8/4: tensor 'x' size mismatch at index 0. expected 5, actual 9
- 8/3: tensor 'x' size mismatch at index 0. expected 4, actual 9
- 8/2: tensor 'x' size mismatch at index 0. expected 3, actual 9
- 8/1: tensor 'x' size mismatch at index 0. expected 2, actual 9
- 8/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
torch._dynamo hit config.recompile_limit (8)
function: 'fn' (/tmp/ipykernel_990/3054308037.py:1)
last reason: 8/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
To log all recompilation reasons, use TORCH_LOGS="recompiles".
To diagnose recompilation issues, see https://pytorch.com.tw/docs/stable/torch.compiler_troubleshooting.html
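The warning above reports which limit was hit (exposed as config.recompile_limit in the log). If you want to inspect the configured limits programmatically before adjusting them, both knobs live on torch._dynamo.config. A minimal sketch; default values and attribute names may differ across PyTorch versions.

import torch._dynamo

# Per-function recompile limit; the warning above reports it
# as config.recompile_limit (8 by default).
print(torch._dynamo.config.cache_size_limit)

# Accumulated limit; the exact distinction between the two values
# is described in torch/_dynamo/cache_size.py.
print(torch._dynamo.config.accumulated_cache_size_limit)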
If you know that the number of recompilations has a reasonable constant upper bound, you can raise the cache size limit. If the cost of recompiling outweighs the benefit of compilation, then you can consider lowering the cache size limit.
torch._dynamo.config.cache_size_limit = 16

@torch.compile(dynamic=False)
def gn(x):
    return x + 1

for i in range(1, 10):
    gn(torch.ones(i))
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 2
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 3
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 3
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 4
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 4
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 4
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 5
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 5
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 5
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 5
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 6
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 6
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 6
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 6
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 6
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 7
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 7
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 7
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 7
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 7
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 7
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 8
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 8
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 8
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 8
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 8
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 8
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 8
Recompiling function gn in /tmp/ipykernel_990/887097224.py:2
triggered by the following guard failure(s):
- 9/7: tensor 'x' size mismatch at index 0. expected 8, actual 9
- 9/6: tensor 'x' size mismatch at index 0. expected 7, actual 9
- 9/5: tensor 'x' size mismatch at index 0. expected 6, actual 9
- 9/4: tensor 'x' size mismatch at index 0. expected 5, actual 9
- 9/3: tensor 'x' size mismatch at index 0. expected 4, actual 9
- 9/2: tensor 'x' size mismatch at index 0. expected 3, actual 9
- 9/1: tensor 'x' size mismatch at index 0. expected 2, actual 9
- 9/0: tensor 'x' size mismatch at index 0. expected 1, actual 9
Graph Breaking to Reduce Recompilation Costs
If a large graph keeps recompiling and causing high compile times, you can intentionally introduce a graph break in order to reduce recompilation costs, at the expense of a performance hit.
def very_large_function(x):
    return x + 1

@torch.compile(dynamic=False)
def fn(x, c):
    y = very_large_function(x)  # recompiled every time
    return y + c

for i in range(1, 5):
    fn(torch.ones(3), i)

@torch.compile(dynamic=False)
def gn(x, c):
    y = very_large_function(x)  # compiled only once
    torch._dynamo.graph_break()
    return y + c  # recompiled every time

for i in range(1, 5):
    gn(torch.ones(3), i)
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/1: c == 2 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function fn in /tmp/ipykernel_990/2876112129.py:4
triggered by the following guard failure(s):
- 10/2: c == 3 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/1: c == 2 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
- 10/0: c == 1 # return y + c # mp/ipykernel_990/2876112129.py:7 in fn
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
Recompiling function torch_dynamo_resume_in_gn_at_15 in /tmp/ipykernel_990/2876112129.py:15
triggered by the following guard failure(s):
- 12/2: c == 3 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/1: c == 2 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15
- 12/0: c == 1 # return y + c # recompiled every time # mp/ipykernel_990/2876112129.py:16 in torch_dynamo_resume_in_gn_at_15