PyTorch 2.0 NNModule 支援#
創建於: 2023年4月06日 | 最後更新於: 2025年6月10日
作者: Will Constable
torch.compile 對 torch.nn.Module 物件有特殊處理,它會以不同於追蹤任意 Python 類的方式來追蹤它們,目的是透過做出關於結構的假設來生成更快的程式碼。
本文件描述了由於這種專業化而產生的一些權衡或邊緣情況。
NNModule Hook 支援#
之前,torch.compile 不支援 nn.Modules 上的 hooks,如果註冊了 hooks,它們在編譯後的程式中將被忽略。事實上,許多使用者根本不使用 nn.Module hooks,或者只在除錯工作流中使用它們,但存在將 nn.Module hooks 與 torch.compile 結合使用的有效用例。
透過 nn.Module.call 實現編排的 Hooks 包括 _forward_pre_hooks、forward_hooks、_backward_pre_hooks 和 _backward_hooks,並將被引用為“call hooks”。這些 hooks 在 torch.compile 中得到部分支援,但存在以下限制。
另一類 Hooks 包括 _state_dict_hooks 及其 pre 和 load_ 變體,它們仍然不受 torch.compile 支援。
nn.Module.__call__ Hooks 用法和限制#
預設情況下,torch.compile 會追蹤 nn.Module.__call__ 的內容,這意味著它會遇到並執行前向/預前向 hooks。如果您在呼叫 torch.compile 之前註冊了 hooks,並且之後不移除或更改 hooks,那麼您的用例應該得到預設支援。
後向/預後向 hooks 通常也得到支援,但有類似的注意事項:目前在 dynamo 中訪問 backward_hooks 字典時會發生圖中斷 (graph-breaks),這可能透過一些工作來避免。圖中斷也會影響後向 hooks 的觸發時機,因為圖段被作為 autograd-functions 執行,它們會同時產生所有梯度。假設 dynamo 可以避免因存在後向 hooks 而導致圖中斷,我們仍然期望一系列模組的後向 hooks 在整個編譯圖的後向執行後一起觸發。
“允許模組”上的 hooks torch.compile 特別處理常見的模組,如 torch.conv,以及難以追蹤的模組,允許它們在 dynamo 圖中被不透明地呼叫,而不是被 dynamo 追蹤。對於這類模組,hooks 當前會觸發圖中斷,以便受影響的模組在 dynamo 外部執行。根據模型,這可能會導致顯著的效能下降,並且需要額外的工作來改進此支援。
skip_nnmodule_hook_guards 預設情況下,torch._dynamo.config.skip_nnmodule_hook_guards 設定為 True,這意味著不會在每個 nn.Module hook 字典上安裝 guards,從而透過減少 guard 執行時間來提高執行時效能,但代價是無法在編譯後發現任何 hook 字典被更改。
如果您希望在編譯後能夠移除或修改 hooks,並且讓 torch.compile 做出適當的反應(透過重新編譯),那麼您需要將 skip_nnmodule_hook_guards=False,並預期由於增加了 guards 而產生的執行時開銷。
TODO:確認後向/預後向 hooks 是否工作,並相應地記錄。
state_dict Hooks#
State dict hooks 尚未在 torch.compile 中得到支援。
TODO:如果 hook 觸發圖中斷,則發出一次警告。如果存在 hook,則發出一次警告以指向本文件。