評價此頁

Fake tensor#

建立日期:2023 年 5 月 19 日 | 最後更新日期:2025 年 6 月 13 日

程式碼: fake_tensor.py

動機#

在進行 Dynamo 符號求值和編譯器傳遞時,我們經常希望能夠執行張量運算來了解輸出的尺寸/資料型別/裝置,而無需實際執行這些運算(或破壞已有的張量),因為那樣會更慢(如果進行大量計算)並且佔用大量記憶體(如果您的編譯器在編譯程式時需要使用 GPU 記憶體,那將非常糟糕)。Fake tensor 在所有方面都類似於真實張量,只是它實際上沒有任何資料。例如,在進行 Dynamo 追蹤時,我們需要追蹤使用者的 Tensor 程式碼並回答關於中間結果的問題(例如,如果使用者對中間張量進行條件判斷)。沒有 fake tensor,我們將無法獲得這些查詢的準確資訊。

類似地,假設您想儲存張量的元資料,例如儲存在 FX IR 節點上(meta[‘val’])。您可以直接在節點上儲存一個 fake tensor,它將為您提供張量所需的所有元資料,包括您可能未考慮到的細微之處(例如,別名關係)。

整體架構#

所有 fake tensor 都與 FakeTensorMode 相關聯。因為 fake tensor 的主要用例是對真實張量進行分析,所以一般的工作流程是:您有一堆真實張量,分配一個 FakeTensorMode,然後使用 from_real_tensor 將所有這些真實張量轉換為 fake tensor,然後對 fake tensor 進行操作。特別是,FakeTensorMode 維護一個持久的對映表,將張量(和儲存)對映到相同的儲存。如果您多次 fakeify 同一個張量,您將得到相同的 fake tensor;如果您 fakeify 兩個相互別名的張量,您將得到兩個別名相同 fake storage 的 fake tensor。FakeTensors 是 tensor subclass,因此如果您對它們進行運算,您將自動獲得一個 fake tensor,但通常您希望在 FakeTensorMode 處於活動狀態時對 fake tensor 進行運算(例如,如果您正在執行 FX 傳遞);張量運算會自動開啟 fake tensor 模式並重試。

Fake tensor 表示為 meta tensor 的 __torch_dispatch__ tensor subclass。這意味著在底層,fake tensor 是 meta device tensors;然後它們使用額外的可擴充套件性鉤子,特別是 dispatch_device,來謊報張量的實際裝置。這是早期 fake tensor 中比較容易出錯的部分:有時,fake tensor 會過於擅長偽裝成 CPU/CUDA 等,最終會導致 CPU 核心被呼叫,而 fake tensor 試圖解引用資料指標,這顯然行不通。如果您在 fake tensor 程式碼中遇到段錯誤,這是您應該首先檢查的地方:C++ 回溯是否在 CPU 核心(意外!)還是 meta 核心(預期!)中?Meta 核心類似於真實核心,但它所做的只是分配輸出,而不執行任何資料計算。

Tensor subclass 必須定義如何實現各種運算。這是通用的 fake tensor 過程:

  • 在輸入 fake tensors 上執行 meta 核心,將它們重新解釋為 meta tensors。這是透過一個特殊的上下文管理器 in_kernel_invocation_manager 來完成的,它指示 PyTorch 將 fake tensors 視為其底層的 meta tensors,而不是“解包” fake tensors 為 meta tensors(fake tensor 是 meta tensor)。Fake tensors 的表示方式是為了避免同步兩組元資料(meta tensor 的元資料和 fake tensor 的元資料);“is a”關係確保只有一組規範的元資料副本。

  • 如果您是工廠函式,則會改用 device='meta' 呼叫底層工廠函式。

  • 將生成的 meta tensor 轉換為 fake tensor,計算張量的輸出裝置應該是什麼(這通常是微不足道的,但有時並非如此,例如,CPU 標量提升,或裝置轉換運算)。

API:重要部分#

非 PT2 用途(檢視 test/test_fake_tensor.py 以獲取更多示例)

# Create a fake mode
from torch._subclasses.fake_tensor import FakeTensorMode
fake_mode = FakeTensorMode()
converter = fake_mode.fake_tensor_converter
# Fakeify some real tensors
fake_x = converter.from_real_tensor(fake_mode, x)
with fake_mode:
    # Do some operations on the fake tensors
    fake_y = fake_x * 2
    # Factory operations automatically get fakeified in the context manager
    fake_z = torch.empty(20)

問:為什麼輸入是真實的張量?

答:在 PT2 上下文中,這是因為您通常是即時編譯,因此對於要編譯的圖的所有輸入,您已經擁有“真實”的輸入,因為您在程式執行時進行編譯。

PT2 AOTAutograd 之前的使用(這很不尋常,您可能不想這樣做)

# Fake mode is not enabled!
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(args)
# if fake_mode isn't None
converter = fake_mode.fake_tensor_converter
fake_args = [converter.from_real_tensor(fake_mode, arg) for arg in args]
with fake_mode:
    ... # do stuff with the fake args, if needed ...

detect_fake_mode 將在多個位置搜尋以嘗試找到與生命週期關聯的“那個” fake tensor 模式。通常它會從追蹤上下文中提取。

PT2 AOTAutograd 之後的使用

# Fake mode is enabled! example_inputs is typically fake already
# TODO: we probably want to change this
# Still do this to access fake mode
fake_mode = detect_fake_mode(example_inputs)
# But in general you don't have to turn it on

其他有用資訊

from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
    ... # fake mode is disabled here, you can do real tensor compute

何時可能希望停用 fake tensor 模式?通常您不希望這樣做。我們發現一個有用的細微之處是實現 fake tensor 上的常量傳播:在這種情況下,即使在 fake tensor 模式下,我們也需要進行一些實際的張量計算。

import FakeTensorProp from torch.fx.passes.fake_tensor_prop
gm: GraphModule
real_inputs: List[Tensor]
FakeTensorProp(gm).propagate(*real_inputs)
# This will populate meta['val'] on all the FX nodes with a fake tensor
# or if you have a preexisting fake mode, you should use it
FakeTensorProp(gm, mode=fake_mode).propagate(*real_inputs)
# There is also propagate_dont_convert_inputs if your inputs are already fake
fake_inputs: List[FakeTensor]
FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)

細節#

自動轉換還是不自動轉換?最初,FakeTensorMode 不會在 FakeTensorMode 區域內嘗試計算時自動 fakeify 真實張量。這樣做的動機是為了防止以下“腳踏實地”的陷阱:

with FakeTensorMode():
    real_tensor.t_()

這段程式碼應該做什麼?如果我們實際上修改了真實張量上的元資料,那將是令人驚訝的。但同時,也沒有明顯的建立 FakeTensor 的機會。因此,我們保守地決定引發錯誤:“在 FakeTensorMode 中呼叫具有非 Fake Tensor 輸入的運算子尚不支援。請先將所有 Tensor 轉換為 FakeTensors。”

在實踐中,這個錯誤非常煩人。例如,假設您有一個真實的 nn.Module,並希望透過它傳遞 fake tensor。您需要某種方式來 fakeify nn.Module。這促成了 FakeCopyMode 的出現。

最終,我們放棄了,並添加了自動 fakeification。然而,在許多 FakeTensorMode 的用法中,這仍然不是預設啟用的。

Fake tensor 上的元資料變異 如果您有一個 fake tensor,並且對其執行 t_(),fake tensor 上的元資料會發生變化。表面上看這是合理的,但有時您也希望將 fake tensor 作為元資料儲存在 FX 節點上;變異 fake tensor 是不好的,因為它會使舊的元資料失效!

事實上,這裡存在一個根本性的矛盾,即 fake tensors 維護著關於張量的極其準確的元資料,直到包括物件身份。如果 FX 圖中的物件元資料隨時間變化,實際上沒有辦法表示這種隨時間的變化。大多數時候,我們對 FX 的嚴肅分析是在函式化圖上進行的,而這些圖沒有這個特性,但偶爾您需要在非函式化圖上進行分析。也許將 fake tensor 放入 meta[‘val’] 是個錯誤。

關於 tensor subclass#

Fake tensor 同時使用了 subclass 和 mode tensor subclass 模式,其中 FakeTensor.__torch_dispatch__ 啟用了與 fake tensor 關聯的 FakeTensorMode,然後重新排程(依賴 FakeTensorMode 來完成繁重的工作)。如果 fake tensor 運算收到一個它不識別的 subclass 引數,它將返回 NotImplemented,讓另一個 subclass 有機會先執行(希望將其“脫糖”為普通張量運算),然後再重試。這可能導致無限迴圈。

每個單獨的運算子是如何實現的?#

不幸的是,任何給定的運算子可能實現的地方非常複雜。以下是一些重要的需要了解的情況:

  • Tensor subclasses 支援有限的常量傳播,如果元素數量非常少(這有助於處理一些我們立即呼叫 item() 的情況)。

  • 出於效能原因,我們為某些運算子提供了一些快速路徑實現,這些實現完全在 fake tensor 中完成。

  • 如果您使用 @custom_op 來生成自定義張量,這些張量將直接向 fake tensor 註冊 impl_abstract。

  • Fake tensor 本身對裝置轉換運算有一些硬編碼的特殊情況。

  • 如果沒有 meta 實現也沒有分解,我們將生成真實的零填充張量並嘗試直接執行該運算子以找出結果。這可能會導致段錯誤,如果該運算子嘗試使用資料進行索引,因此我們預設不對自定義運算子啟用此功能。

轉換器是如何工作的?#

因為 fake tensor 用於對張量的精確屬性非常敏感的場景,所以 fake tensor 的轉換非常小心,會保留 leaf-ness、requires_grad-ness、別名以及許多其他屬性。大部分繁重的工作都在 MetaConverter 中。

效能特徵#

您可能會認為 fake tensor 速度很快,因為它們不進行任何張量計算。但在小張量尺寸下,我們完全受限於開銷,而且,fake tensor 是用 Python 編寫的,我們經常為單個張量運算執行大量工作(因為它們是作為分解實現的)。所以實際上,fake tensor 速度相當慢,尤其是在涉及符號形狀時。目前 fake tensor 有兩個重要的快速路徑,在實踐中效果顯著:

  • Pointwise 運算不透過 PrimTorch 分解,而是我們手工編碼了它們的傳播規則。

  • 如果可能,我們應該這樣做。

Fake tensor 的 fake tensor?#

人們有興趣將 fake tensor 作為使用者輸入傳送到 PT2 堆疊,這意味著我們需要能夠建立一個 fake tensor 的 fake tensor。這目前並沒有真正得到支援,但也許做起來不會太難。

與動態形狀的互動#

每個 FakeTensorMode 都包含一個 ShapeEnv,用於跟蹤所有符號形狀資訊。它們的生命週期通常是繫結的:它們一起存在,一起消亡。

因為 FakeTensorMode 有一個 ShapeEnv(但 meta 實現沒有),所以依賴於資料的、需要分配未備份的 SymInt 的 meta 函式存在於 fake tensor 中。Fake tensor 還負責 memoize 未備份的 SymInt,因此,例如,如果您對同一個 fake tensor 呼叫 nonzero() 兩次,您會得到相同的符號大小。