動態形狀#
建立日期:2023年5月19日 | 最後更新日期:2025年6月10日
程式碼: symbolic_shapes.py
另請參閱: 動態形狀手冊
動機#
深度學習編譯器通常只適用於靜態形狀,也就是說,它們生成的編譯程式只適用於單一的特定輸入形狀配置,並且在任何輸入形狀發生變化時都必須重新編譯。這個假設對於當今絕大多數常用的深度學習模型來說效果很好,但在少數情況下是不夠的。
某些維度,例如批次大小或序列長度,可能會有所不同。例如,執行自適應批次的推理服務將根據其批次視窗內收到的請求數量,以不同的批次大小執行推理請求。我們可能還希望僅將可變大小的序列填充到批次內的最大序列長度,而這個最大序列長度可能因批次而異。
某些模型表現出資料依賴的輸出形狀,也就是說,其輸出和中間張量的尺寸可能取決於實際輸入資料,而這些資料在不同執行之間可能會有所不同。例如,檢測模型可能首先生成可變數量的潛在邊界框,然後再執行一個更昂貴的影像識別模型來確定主題是否在邊界框內。邊界框的數量是資料依賴的。
資料依賴形狀的一個特別重要的案例發生在處理稀疏表示時,例如稀疏張量、鋸齒狀張量和圖神經網路。在所有這些情況下,要處理的資料量取決於問題的稀疏結構,而這通常會以資料依賴的方式變化。
在支援動態形狀時,我們選擇不支援動態秩程式,例如,輸入張量的維度會發生變化的程式,因為這種模式在實際深度學習程式中很少出現,而且它避免了對形狀符號列表進行歸納推理的需要。
精簡公共API#
PyTorch 2.1 中的預設動態行為是
PT2 預設假定所有內容都是靜態的
如果因為某個尺寸發生變化而重新編譯,我們將嘗試將該尺寸作為動態尺寸重新編譯(已更改的尺寸很可能在將來繼續變化)。這種泛化可能會失敗(例如,因為使用者程式碼對所討論的尺寸進行了條件分支,或者 PT2 中缺少動態形狀支援)。如果您想了解 PT2 為什麼對某些程式碼進行了過度特化,請執行
TORCH_LOGS=dynamic並查詢顯示何時新增 guard 以及原因的“eval”條目。如果您事先知道某個張量將是動態的,您可以使用
torch._dynamo.mark_dynamic(tensor, dim)跳過第一次重新編譯。如果您事先知道此維度可以取min和max值,您可以指定torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max)如果您指定
torch.compile(dynamic=False),我們將關閉自動動態形狀的重編譯,並始終為每個不同的尺寸進行重編譯。反之,如果您指定torch.compile(dynamic=True),我們將嘗試儘可能地使所有內容動態化。這對於小型運算子很有用;如果您在大型模型上嘗試此操作,它將(1)很可能導致 PT2 崩潰,並且(2)執行緩慢而沒有好理由。您可以使用
TORCH_COMPILE_DYNAMIC_SOURCES環境變數或設定torch.compiler.config.dynamic_sources來白名單特定源以標記為動態。這對於具有圖中斷的大型模型特別有用,因為您可以跨圖中斷保持動態性,因為源名稱保持一致。您還可以使用它來標記整數為動態。格式是逗號分隔的源名稱列表,例如"L['x'], L['y']"。您也可以使用正則表示式,例如"L\['x.*'\], L\['y.*'\]")。此白名單優先於其他標誌,如dynamic=False、force_nn_module_property_static_shapes和force_parameter_static_shapes。有時找出要標記為動態的正確輸入可能很麻煩。如果您願意在第一個批次上承受效能損失,我們還有另一個經濟實惠的選項是 eager_then_compile 模式,它會為您推導動態性。有關更多詳細資訊,請參閱 torch.compiler.set_stance。
Guard 模型#
在考慮如何為 TorchDynamo 和 TorchInductor 新增動態形狀支援時,我們做了一個重要的設計決策:為了重用針對 PyTorch API 編寫的分解和其他現有程式碼,我們必須能夠跟蹤動態形狀。與可能捕獲條件的兩個分支的完全符號化系統不同,我們總是選擇一個分支,並在假設我們將來僅在對該分支做出相同選擇時使用此跟蹤的假設下特化我們的跟蹤。為此,我們為每個符號化大小維護一個“提示”,說明其在編譯時的具體值(由於 TorchDynamo 是即時編譯器,它總是知道實際的輸入大小。)當我們對張量進行條件判斷時,我們只需查閱提示即可確定採取哪個分支。
這極大地簡化了我們生成的符號形狀公式,但意味著我們有一個更復雜的 guard 管理系統。例如,考慮以下程式
def f(x, y):
z = torch.cat([x, y])
if z.size(0) > 2:
return z.mul(2)
else:
return z.add(2)
我們將用 TorchInductor 編譯的最終 IR 將是 torch.cat([x, y]).add(2) 或 torch.cat([x, y]).mul(2)(條件已展平),但要確定我們在哪個分支,我們需要知道 z 的大小,這是一箇中間值。由於 TorchDynamo 必須預先知道編譯跟蹤是否有效(我們不支援像某些 JIT 編譯器那樣的退出),我們必須能夠將 z.size(0) 表示為輸入 x.size(0) + y.size(0) 的表示式。這是透過為 PyTorch 中的所有運算子編寫元函式來實現的,這些元函式可以在不實際對節點執行計算的情況下將大小資訊傳播到張量的輸出。
總體架構#
符號形狀工作流程
當我們在 Dynamo 中開始編譯一個幀時,我們會分配一個 ShapeEnv(附加到 FakeTensorMode),它會跟蹤符號形狀狀態。
我們在進入時為張量分配符號大小(靜態還是動態是策略決策,有一些調整)。
我們透過運算子傳播符號大小,同時維護(1)FX IR 以便我們能夠忠實地匯出符號計算,以及(2)表示大小變數的 Sympy 表示式,以便我們能夠對它們進行推理。
當我們根據符號大小進行條件判斷時,無論是在 Dynamo 跟蹤還是在 Inductor 最佳化中,我們都會根據條件新增 guard。這些 guard 可以從 Python 和 C++ 中誘導。
這些 guard 可以對符號變數誘導進一步的簡化。例如,如果您斷言
s0 == 4,我們現在可以將s0的所有出現替換為4。當我們完成跟蹤和最佳化時,我們將所有這些 guard 安裝到編譯後的程式碼中;只有當所有 guard 都評估為 true 時,編譯後的程式碼才是可重用的。
重要檔案
C++ SymInt API:
c10/core/SymInt.h、SymFloat.h、SymBool.hPython SymInt API:
torch/__init__.py(查詢SymInt/SymFloat/SymBool)C++ 管道:
c10/core/SymNodeImpl.h、torch/csrc/utils/python_symnode.h、torch/csrc/jit/python/init.cppPython 基礎設施:
torch/fx/experimental/symbolic_shapes.py其他重要檔案:
torch/_subclasses/fake_tensor.py、torch/_meta_registrations.py、分解、PrimTorch 引用
精簡內部 API#
理解 Python 類層次結構
SymInt/SymFloat/SymBool:這些是使用者可見的類,模擬其對應的 int/float/bool。如果您將兩個 SymInt 相加,我們會給您一個新的 SymInt,它會符號化地跟蹤整數加法已發生。
SymNode:這是內部結構(可透過例如
symint.node訪問),它儲存實際的符號跟蹤資訊。SymNode 是型別擦除的;這使得表示混合型別操作更加方便。請注意,技術上您不必從 SymInt 呼叫 Python SymNode;例如,XLA 的 C++SymNodeImpl將取代 SymNode。ShapeEnv:每次編譯的上下文狀態,它跟蹤我們到目前為止累積的所有自由符號和 guard。每個 SymNode 都記錄其 ShapeEnv(但反之則不然;只有當 SymNode 參與 guard 時才會使用它們)。
C++ 也很相似
c10::SymInt/SymFloat/SymBool:使用者可見的類,模擬 int/float/bool。
c10::SymNode/SymNodeImpl:類似於 SymNode
C++ 中沒有 ShapeEnv;為了方便除錯,整個符號推理機制都在 Python 中。
當您編寫可以用 make_fx 跟蹤的程式碼時,它必須能夠處理其中的 SymInt/SymFloat/SymBool 流。 動態形狀手冊 提供了一些關於如何執行此操作的指導。
未備份的 SymInt#
為了解析控制流,我們檢查符號整數的提示(即實際值)來確定要轉到哪個分支。但是,在某些情況下,我們可能沒有提示:所謂的未備份符號整數是在資料依賴操作(如 .nonzero() 或 .item())中出現的大小變數。對這些符號整數執行控制流是非法的,因此我們必須在這些操作上進行圖中斷。
如果 naively 實現,這過於嚴格:大多數 PyTorch 程式在嘗試對未備份的符號整數執行任何操作時都會立即失敗。以下是對使其真正工作的最重要的增強功能:
在張量建立時,PyTorch 會預先計算有關張量的許多資料;例如,如果您使用
empty_strided建立張量,我們會主動排序跨步並確定張量是否不重疊且密集。排序會產生大量 guard。但是,更常見的是直接使用像empty這樣的更高階 API 來生成張量,後者保證會生成非重疊且密集的張量。我們修改了 PyTorch 以避免不必要地重新計算這些屬性。即使需要非平凡的計算,有時某個屬性根本不會被查詢。將這些預先計算的屬性設為惰性使我們能夠避免在未備份的符號整數上設定 guard,除非確實需要。
整數張量中的資料通常不確定是否為非負數。但是,我們提供了一個 API
constrain_range,使用者可以透過它指定大小的上下界由已知限制。
與動態 API 類似,存在相應的未備份 API:即您可以使用 mark_unbacked 而不是 mark_dynamic,並使用 TORCH_COMPILE_UNBACKED_SOURCES 而不是 TORCH_COMPILE_DYNAMIC_SOURCES 來告訴編譯器將輸入標記為未備份。
在 PT2 的未來版本(PT2.1 之後)中,我們將擴充套件我們的推理系統,以根據用法推斷未備份的符號整數是 size-like 的。例如,如果您將 .item() 呼叫結果傳遞給像 torch.empty 這樣的工廠函式,我們將自動推斷結果是 size(因為如果不是,它將失敗。)此假設將在執行時得到驗證,如果未滿足,將引發錯誤。