Autograd 機制#
創建於: 2017年1月16日 | 最後更新於: 2025年6月16日
本文件將概述 autograd 的工作原理和操作記錄方式。深入理解這些細節並非強制要求,但我們推薦您熟悉它們,因為這將幫助您編寫更高效、更簡潔的程式,並有助於除錯。
Autograd 如何編碼歷史#
Autograd 是一個反向自動微分系統。從概念上講,autograd 在執行操作時,會記錄所有建立資料的操作,形成一個有向無環圖(DAG),其葉節點是輸入張量,根節點是輸出張量。透過追蹤這個圖從根到葉的路徑,您可以使用鏈式法則自動計算梯度。
在內部,autograd 將這個圖表示為一系列 Function 物件(實際上是表示式)組成的圖,這些物件可以被 apply() 來計算圖的求值結果。在進行前向傳播時,autograd 會同時執行請求的計算,並構建一個表示計算梯度的函式圖(每個 torch.Tensor 的 .grad_fn 屬性是進入此圖的入口)。當前向傳播完成後,我們在後向傳播中評估這個圖來計算梯度。
需要注意的是,這個圖在每次迭代時都會從頭開始重建,這正是它允許使用任意 Python 控制流語句(這些語句可以在每次迭代中改變圖的整體形狀和大小)的原因。您不必在啟動訓練前就編碼所有可能的路徑——您執行什麼,就對什麼進行微分。
儲存的張量#
某些操作在執行後向傳播時,需要在前向傳播過程中儲存中間結果。例如,函式 儲存輸入 以計算梯度。
在定義自定義 Python Function 時,您可以使用 save_for_backward() 在前向傳播時儲存張量,並在後向傳播時使用 saved_tensors 檢索它們。更多資訊請參閱 擴充套件 PyTorch。
對於 PyTorch 定義的操作(例如 torch.pow()),張量會根據需要自動儲存。您可以(出於教育或除錯目的)透過查詢以 _saved 為字首的屬性來探索特定 grad_fn 儲存了哪些張量。
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
print(x.equal(y.grad_fn._saved_self)) # True
print(x is y.grad_fn._saved_self) # True
在之前的程式碼中,y.grad_fn._saved_self 指向與 x 相同的 Tensor 物件。但這並非總是如此。例如:
x = torch.randn(5, requires_grad=True)
y = x.exp()
print(y.equal(y.grad_fn._saved_result)) # True
print(y is y.grad_fn._saved_result) # False
在底層,為了防止引用迴圈,PyTorch 在儲存張量時對其進行了*打包*,並在讀取時將其*解包*到一個不同的張量中。在這裡,您透過訪問 y.grad_fn._saved_result 獲得的張量物件與 y 是不同的張量物件(但它們仍然共享相同的儲存)。
張量是否會被打包成不同的張量物件,取決於它是否是其自身 grad_fn 的輸出,這是一個實現細節,可能會發生更改,使用者不應依賴它。
您可以透過 儲存張量的鉤子 來控制 PyTorch 的打包/解包行為。
不可微函式的梯度#
自動微分的梯度計算僅在所使用的每個基本函式可微時才有效。不幸的是,實踐中使用的許多函式不具備此屬性(例如 0 處的 relu 或 sqrt)。為了儘量減少不可微函式的影響,我們透過以下規則順序定義基本操作的梯度:
如果函式在該點可微,則使用該點的梯度。
如果函式是凸函式(至少在區域性),則使用最小范數的次梯度。
如果函式是凹函式(至少在區域性),則使用最小范數的超梯度(考慮 -f(x) 並應用上一條)。
如果函式在該點有定義,則透過連續性定義該點的梯度(注意這裡可能出現
inf,例如對於sqrt(0))。如果存在多個可能的值,則任意選擇一個。如果函式未定義(例如
sqrt(-1)、log(-1)或輸入為NaN時的大多數函式),則使用的梯度值是任意的(我們也可能丟擲錯誤,但並非保證)。大多數函式將使用NaN作為梯度,但出於效能原因,某些函式將使用其他值(例如log(-1))。如果函式不是確定性對映(即它不是一個數學函式),它將被標記為不可微。這將在後向傳播中導致錯誤(如果在
no_grad環境外用於需要 grad 的張量)。
Autograd 中的除零錯誤#
在 PyTorch 中進行除零運算(例如 x / 0)時,前向傳播將遵循 IEEE-754 浮點算術生成 inf 值。雖然這些 inf 值可以在計算最終損失之前被遮蔽掉(例如透過索引或掩碼),但 autograd 系統仍然會追蹤並對整個計算圖進行微分,包括除零運算。
在反向傳播過程中,這可能導致梯度表示式出現問題。例如:
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div # Results in [inf, 1]
mask = div != 0 # [False, True]
loss = y[mask].sum()
loss.backward()
print(x.grad) # [nan, 1], not [0, 1]
在此示例中,即使我們僅使用了遮蔽後的輸出(其中排除了除零運算),autograd 仍然透過完整的計算圖(包括除零運算)來計算梯度。這會導致遮蔽元素的梯度為 nan,從而可能導致訓練不穩定。
為了避免此問題,有幾種推薦的方法:
在除法之前遮蔽
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
mask = div != 0
safe = torch.zeros_like(x)
safe[mask] = x[mask] / div[mask]
loss = safe.sum()
loss.backward() # Produces safe gradients [0, 1]
使用 MaskedTensor(實驗性 API)
from torch.masked import as_masked_tensor
x = torch.tensor([1., 1.], requires_grad=True)
div = torch.tensor([0., 1.])
y = x / div
mask = div != 0
loss = as_masked_tensor(y, mask).sum()
loss.backward() # Cleanly handles "undefined" vs "zero" gradients
關鍵原則是阻止除零運算被記錄在計算圖中,而不是事後遮蔽其結果。這確保了 autograd 只會計算有效運算的梯度。
當處理可能產生 inf 或 nan 值的操作時,這一點很重要,因為遮蔽輸出並不能阻止產生有問題的梯度。
區域性停用梯度計算#
Python 提供了幾種機制來區域性停用梯度計算:
要停用程式碼塊的整體梯度,可以使用 no-grad 模式和 inference 模式等上下文管理器。要更精細地排除子圖的梯度計算,可以設定張量的 requires_grad 欄位。
下面,除了討論上述機制外,我們還將描述 evaluation 模式(nn.Module.eval()),該方法不用於停用梯度計算,但由於其名稱,經常與這三種模式混淆。
設定 requires_grad#
requires_grad 是一個標誌,預設為 false,*除非包裝在* nn.Parameter *中*,它允許對梯度計算進行子圖的精細排除。它在前向和後向傳播中都生效。
在前向傳播過程中,只有當至少一個輸入張量需要 grad 時,操作才會被記錄在後向圖中。在後向傳播(.backward())過程中,只有 requires_grad=True 的葉張量才會在其 .grad 欄位中累積梯度。
需要注意的是,儘管每個張量都有此標誌,但*設定*它僅對葉張量(即沒有 grad_fn 的張量,例如 nn.Module 的引數)有意義。非葉張量(即有 grad_fn 的張量)是與它們相關的後向圖的張量。因此,它們的梯度將作為中間結果,以計算需要 grad 的葉張量的梯度。從這個定義可以看出,所有非葉張量都將自動具有 requires_grad=True。
設定 requires_grad 應該是您控制模型哪些部分參與梯度計算的主要方式,例如,如果您需要在模型微調期間凍結預訓練模型的某些部分。
要凍結模型的一部分,只需對您不希望更新的引數應用 .requires_grad_(False)。如上所述,由於使用這些引數作為輸入的計算在前向傳播中不會被記錄,因此它們在後向傳播中不會更新其 .grad 欄位,因為它們根本不會成為後向圖的一部分,這正是所期望的。
由於這是一個非常常見的模式,requires_grad 也可以在模組級別使用 nn.Module.requires_grad_() 進行設定。當應用於模組時,.requires_grad_() 會影響該模組的所有引數(這些引數預設具有 requires_grad=True)。
Grad 模式#
除了設定 requires_grad 之外,還可以從 Python 中選擇三種 grad 模式,它們可以影響 autograd 在 PyTorch 中如何處理計算:預設模式(grad 模式)、no-grad 模式和 inference 模式,這些都可以透過上下文管理器和裝飾器進行切換。
模式 |
排除操作不被記錄在後向圖中 |
跳過額外的 autograd 跟蹤開銷 |
在模式啟用時建立的張量稍後可以在 grad 模式下使用 |
示例 |
|---|---|---|---|---|
預設 |
✓ |
前向傳播 |
||
|
✓ |
✓ |
最佳化器更新 |
|
|
✓ |
✓ |
資料處理、模型評估 |
預設模式(Grad 模式)#
“預設模式”是指當我們沒有啟用 no-grad 和 inference 模式時所處的模式。與“no-grad 模式”相對,“預設模式”有時也稱為“grad 模式”。
關於預設模式最重要的一點是,它是唯一一種 requires_grad 生效的模式。在另外兩種模式下,requires_grad 總是被重寫為 False。
no-grad 模式#
在 no-grad 模式下的計算,其行為就好像沒有輸入需要 grad 一樣。換句話說,即使輸入具有 require_grad=True,在 no-grad 模式下的計算也永遠不會被記錄在後向圖中。
當您需要執行不應被 autograd 記錄的操作,但仍希望稍後在 grad 模式下使用這些計算的輸出時,可以使用 no-grad 模式。此上下文管理器方便您為程式碼塊或函式停用梯度,而無需臨時將張量設定為 requires_grad=False,然後再改回 True。
例如,在編寫最佳化器時,no-grad 模式可能很有用:在執行訓練更新時,您希望就地更新引數,而不希望此次更新被 autograd 記錄。您還打算在下一個前向傳播中使用更新後的引數進行計算。
torch.nn.init 中的實現也依賴於 no-grad 模式來初始化引數,以避免在就地更新初始化引數時被 autograd 跟蹤。
推理模式#
推理模式是 no-grad 模式的極端版本。與 no-grad 模式一樣,推理模式下的計算不會被記錄在後向圖中,但啟用推理模式可以使 PyTorch 更快地加速您的模型。這種更好的執行時帶來了缺點:在推理模式下建立的張量在退出推理模式後將無法用於 autograd 記錄的計算。
當您執行與 autograd 沒有互動的計算,並且*不*打算稍後在 autograd 記錄的任何計算中使用在推理模式下建立的張量時,啟用推理模式。
建議您嘗試在不需要 autograd 跟蹤的程式碼部分(例如資料處理和模型評估)中使用推理模式。如果它開箱即用,那將是免費的效能提升。如果您在啟用推理模式後遇到錯誤,請檢查您是否在退出推理模式後,在 autograd 記錄的計算中使用了在推理模式下建立的張量。如果您的場景無法避免此類使用,您可以隨時切換回 no-grad 模式。
有關推理模式的詳細資訊,請參閱 推理模式。
有關推理模式實現細節,請參閱 RFC-0011-InferenceMode。
評估模式(nn.Module.eval())#
評估模式不是一種區域性停用梯度計算的機制。它之所以在此處包含,是因為它有時會被誤認為是這樣的機制。
從功能上看,module.eval()(或等效的 module.train(False))與 no-grad 模式和 inference 模式完全正交。 model.eval() 如何影響您的模型完全取決於模型中使用的特定模組,以及它們是否定義了任何特定於訓練模式的行為。
您負責在模型依賴於 torch.nn.Dropout 和 torch.nn.BatchNorm2d 等模組時呼叫 model.eval() 和 model.train(),這些模組在訓練模式下可能表現不同,例如,為了避免在驗證資料上更新 BatchNorm 的執行統計資訊。
建議您在訓練時始終使用 model.train(),在評估模型(驗證/測試)時使用 model.eval(),即使您不確定模型是否具有特定於訓練模式的行為,因為您使用的模組可能會更新為在訓練和評估模式下具有不同的行為。
Autograd 的就地操作#
在 autograd 中支援就地操作是一個棘手的問題,我們不鼓勵在大多數情況下使用它們。Autograd 的激進緩衝區釋放和重用使其非常高效,很少有情況可以使就地操作顯著降低記憶體使用量。除非您面臨巨大的記憶體壓力,否則您可能永遠不需要使用它們。
限制就地操作適用性的主要原因有兩個:
就地操作可能會覆蓋計算梯度所需的值。
每個就地操作都需要實現來重寫計算圖。原地操作只是分配新物件並保留對舊圖的引用,而就地操作需要將所有輸入的建立者更改為代表此操作的
Function。這可能很棘手,特別是當有許多張量引用同一儲存(例如透過索引或轉置建立)時。如果修改輸入的儲存被任何其他Tensor引用,就地函式將引發錯誤。
就地正確性檢查#
每個張量都維護一個版本計數器,該計數器在每次在任何操作中被標記為髒時遞增。當 Function 為後向傳播儲存任何張量時,還會儲存其包含張量的版本計數器。一旦您訪問 self.saved_tensors,就會進行檢查,如果大於儲存的值,則會引發錯誤。這確保瞭如果您使用就地函式且未看到任何錯誤,您可以確信計算出的梯度是正確的。
多執行緒 Autograd#
autograd 引擎負責執行計算後向傳播所需的所有後向操作。本節將詳細介紹所有有助於您在多執行緒環境中充分利用它的細節。(這僅適用於 PyTorch 1.6+,因為早期版本的行為有所不同。)
使用者可以使用多執行緒程式碼(例如 Hogwild 訓練)來訓練他們的模型,並且不會阻塞併發的後向計算。示例如下:
# Define a train function to be used in different threads
def train_fn():
x = torch.ones(5, 5, requires_grad=True)
# forward
y = (x + 3) * (x + 4) * 0.5
# backward
y.sum().backward()
# potential optimizer update
# User write their own threading code to drive the train_fn
threads = []
for _ in range(10):
p = threading.Thread(target=train_fn, args=())
p.start()
threads.append(p)
for p in threads:
p.join()
請注意,使用者應該意識到一些行為:
CPU 併發#
當您透過 Python 或 C++ API 在 CPU 上的多個執行緒中執行 backward() 或 grad() 時,您期望看到額外的併發,而不是在執行期間將所有後向呼叫序列化(PyTorch 1.6 之前的行為)。
非確定性#
如果您從多個執行緒併發呼叫 backward() 並且有共享輸入(例如 Hogwild CPU 訓練),那麼應該預期非確定性。這可能會發生,因為引數會自動線上程之間共享,因此多個執行緒可能在梯度累積期間訪問並嘗試累積相同的 .grad 屬性。這在技術上是不安全的,它可能導致競態條件,結果可能無法使用。
開發具有共享引數的多執行緒模型的使用者應牢記執行緒模型,並理解上述問題。
可以使用函式式 API torch.autograd.grad() 來計算梯度,而不是 backward(),以避免非確定性。
圖保留#
如果 autograd 圖的一部分線上程之間共享(即,單執行緒執行前向圖的第一部分,然後在多執行緒中執行第二部分),則圖的第一部分被共享。在這種情況下,不同執行緒在同一圖上執行 grad() 或 backward() 可能會在其中一個執行緒的執行過程中銷燬圖,導致另一個執行緒崩潰。Autograd 會像兩次呼叫 backward() 而沒有 retain_graph=True 一樣向用戶丟擲錯誤,並告知使用者他們應該使用 retain_graph=True。
Autograd 節點上的執行緒安全#
由於 Autograd 允許呼叫執行緒驅動其後向執行以實現潛在的並行化,因此確保 CPU 上的執行緒安全非常重要,特別是在共享部分/全部 GraphTask 的並行 backward() 呼叫時。
自定義 Python autograd.Function 由於 GIL 的存在,會自動成為執行緒安全的。對於內建的 C++ Autograd 節點(例如 AccumulateGrad, CopySlices)和自定義 autograd::Function,Autograd 引擎使用執行緒互斥鎖來確保可能具有狀態讀寫操作的 autograd 節點上的執行緒安全。
C++ 鉤子上不存線上程安全#
Autograd 依賴使用者編寫執行緒安全的 C++ 鉤子。如果您希望鉤子在多執行緒環境中正確應用,您需要編寫適當的執行緒鎖定程式碼來確保鉤子是執行緒安全的。
複數 Autograd#
簡而言之:
當您使用 PyTorch 對任何具有複數域和/或值域的函式 進行微分時,梯度是根據函式是更大實值損失函式 的一部分來計算的。計算出的梯度是 (注意 z 的共軛),其負值正是梯度下降演算法使用的最陡下降方向。因此,使現有最佳化器能夠直接與複數引數協同工作的路徑是可行的。
此約定與 TensorFlow 的複數微分約定一致,但與 JAX 不同(JAX 計算 )。
如果您有一個內部使用複數運算的實值實值函式,那麼此約定無關緊要:您將始終獲得如果僅使用實數運算實現該函式所得到的結果。
如果您對數學細節感興趣,或者想知道如何在 PyTorch 中定義複數導數,請繼續閱讀。
什麼是複數導數?#
複數可微的數學定義是對導數的極限定義進行泛化,使其能夠處理複數。考慮一個函式 ,
其中 和 是兩個變數實值函式, 是虛數單位。
使用導數定義,我們可以寫出:
為了使這個極限存在,不僅 和 必須是實可微的,而且 還必須滿足柯西-黎曼方程。換句話說:真實和虛部步驟的極限()必須相等。這是一個更嚴格的條件。
複數可微函式通常被稱為全純函式。它們非常規整,具有您從實值可微函式中學到的所有良好特性,但在最佳化世界中幾乎沒有用處。對於最佳化問題,研究界通常只使用實值目標函式,因為複數不屬於任何有序域,因此具有複數值損失的意義不大。
事實證明,沒有任何有趣的實值目標函式滿足柯西-黎曼方程。因此,全純函式的理論不能用於最佳化,因此大多數人使用維爾廷格演算。
維爾廷格演算出現於...#
因此,我們擁有這套出色的複數可微性和全純函數理論,而我們卻無法利用它,因為許多常用函式並非全純。可憐的數學家該怎麼辦?維爾廷格觀察到,即使 不是全純的,也可以將其重寫為雙變數函式 ,該函式總是全純的。這是因為 分量的實部和虛部可以表示為 和 的形式:
維爾廷格演算建議研究 ,如果 是實可微的,則保證是全純的(另一種思考方式是將其視為座標系變換,從 到 。
從上述方程中,我們得到:
這是您會在Wikipedia上找到的經典的Wirtinger微積分定義。
這種改變帶來了許多優美的推論。
例如,柯西-黎曼方程可以被簡化為僅僅說明 (也就是說,函式可以完全用來表示,而不需要引用).
另一個重要的(有時也是反直覺的)結果是,正如我們稍後將看到的,當我們對實值損失函式進行最佳化時,在進行變數更新時應採取的步驟由(而不是)。
欲瞭解更多資訊,請參閱:https://arxiv.org/pdf/0906.4835.pdf
Wirtinger微積分在最佳化中有何用處?#
音訊和其他領域的研究人員更常使用梯度下降來最佳化具有複雜變數的實值損失函式。通常,這些人將實部和虛部視為可以更新的獨立通道。對於步長和損失,我們可以寫出以下在中的方程:
這些方程如何轉換到複數空間?
發生了一件非常有意思的事情:Wirtinger微積分告訴我們,可以將上面的複數變數更新公式簡化為僅引用共軛Wirtinger導數,這給了我們最佳化中採取的準確步驟。
由於共軛Wirtinger導數給出了實值損失函式的準確最佳化步驟,PyTorch在對具有實值損失的函式進行微分時,會返回該導數。
PyTorch如何計算共軛Wirtinger導數?#
通常,我們的導數公式以grad_output作為輸入,它表示已經計算過的傳入的Vector-Jacobian乘積,即,其中是整個計算(產生實值損失)的損失,而是我們函式的輸出。目標是計算,其中是函式的輸入。實際上,在實值損失的情況下,我們只需計算,儘管鏈式法則暗示我們還需要訪問。如果您想跳過此推導,請檢視本節的最後一個方程,然後跳到下一節。
讓我們繼續使用進行討論,定義為。如上所述,autograd 的梯度約定側重於實值損失函式的最佳化,因此我們假設是更大的實值損失函式的一部分。使用鏈式法則,我們可以寫出:
(1)#
現在使用Wirtinger導數的定義,我們可以寫出:
這裡應該指出的是,由於和是實函式,並且根據我們假設是實值函式的一部分,是實數,因此我們有:
(2)#
即, 等於。
透過求解上述關於和,我們得到:
(3)#
使用 公式 (2),我們得到
(4)#
最後一個方程對於編寫你自己的梯度很重要,因為它將我們的導數公式分解為一個易於手工計算的更簡單的公式。
我該如何寫一個複數函式的導數公式?#
上述帶框的方程給出了所有複數函式導數的通用公式。然而,我們仍然需要計算 和 。有兩種方法可以做到這一點:
第一種方法是直接使用 Wirtinger 導數的定義來計算 和 (使用 和 (可以按常規方式計算)。
第二種方法是使用變數替換技巧,將 重寫為一個二元函式 ,並透過將 和 視為獨立變數來計算共軛 Wirtinger 導數。這通常更容易;例如,如果所討論的函式是全純的,則只會用到 (而 將為零)。
讓我們以 作為示例,其中 。
使用第一種方法計算 Wirtinger 導數,我們得到:
使用 公式 (4),以及 grad_output = 1.0 (這是 PyTorch 中呼叫 backward() 時標量輸出的預設梯度輸出值),我們得到
使用第二種方法計算 Wirtinger 導數,我們直接得到
And using (4) again, we get . As you can see, the second way involves lesser calculations, and comes in more handy for faster calculations.
Hooks for saved tensors#
You can control how saved tensors are packed / unpacked by defining a pair of pack_hook / unpack_hook hooks. The pack_hook function should take a tensor as its single argument but can return any python object (e.g. another tensor, a tuple, or even a string containing a filename). The unpack_hook function takes as its single argument the output of pack_hook and should return a tensor to be used in the backward pass. The tensor returned by unpack_hook only needs to have the same content as the tensor passed as input to pack_hook. In particular, any autograd-related metadata can be ignored as they will be overwritten during unpacking.
An example of such pair is
class SelfDeletingTempFile():
def __init__(self):
self.name = os.path.join(tmp_dir, str(uuid.uuid4()))
def __del__(self):
os.remove(self.name)
def pack_hook(tensor):
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(temp_file):
return torch.load(temp_file.name)
Notice that the unpack_hook should not delete the temporary file because it might be called multiple times: the temporary file should be alive for as long as the returned SelfDeletingTempFile object is alive. In the above example, we prevent leaking the temporary file by closing it when it is no longer needed (on deletion of the SelfDeletingTempFile object).
注意
We guarantee that pack_hook will only be called once but unpack_hook can be called as many times as the backward pass requires it and we expect it to return the same data each time.
警告
Performing inplace operations on the input of any of the functions is forbidden as they may lead to unexpected side-effects. PyTorch will throw an error if the input to a pack hook is modified inplace but does not catch the case where the input to an unpack hook is modified inplace.
Registering hooks for a saved tensor#
You can register a pair of hooks on a saved tensor by calling the register_hooks() method on a SavedTensor object. Those objects are exposed as attributes of a grad_fn and start with the _raw_saved_ prefix.
x = torch.randn(5, requires_grad=True)
y = x.pow(2)
y.grad_fn._raw_saved_self.register_hooks(pack_hook, unpack_hook)
The pack_hook method is called as soon as the pair is registered. The unpack_hook method is called each time the saved tensor needs to be accessed, either by means of y.grad_fn._saved_self or during the backward pass.
警告
If you maintain a reference to a SavedTensor after the saved tensors have been released (i.e. after backward has been called), calling its register_hooks() is forbidden. PyTorch will throw an error most of the time but it may fail to do so in some cases and undefined behavior may arise.
Registering default hooks for saved tensors#
Alternatively, you can use the context-manager saved_tensors_hooks to register a pair of hooks which will be applied to all saved tensors that are created in that context.
示例
# Only save on disk tensors that have size >= 1000
SAVE_ON_DISK_THRESHOLD = 1000
def pack_hook(x):
if x.numel() < SAVE_ON_DISK_THRESHOLD:
return x.detach()
temp_file = SelfDeletingTempFile()
torch.save(tensor, temp_file.name)
return temp_file
def unpack_hook(tensor_or_sctf):
if isinstance(tensor_or_sctf, torch.Tensor):
return tensor_or_sctf
return torch.load(tensor_or_sctf.name)
class Model(nn.Module):
def forward(self, x):
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
# ... compute output
output = x
return output
model = Model()
net = nn.DataParallel(model)
The hooks defined with this context manager are thread-local. Hence, the following code will not produce the desired effects because the hooks do not go through DataParallel.
# Example what NOT to do
net = nn.DataParallel(model)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
output = net(input)
Note that using those hooks disables all the optimization in place to reduce Tensor object creation. For example
with torch.autograd.graph.saved_tensors_hooks(lambda x: x.detach(), lambda x: x):
x = torch.randn(5, requires_grad=True)
y = x * x
Without the hooks, x, y.grad_fn._saved_self and y.grad_fn._saved_other all refer to the same tensor object. With the hooks, PyTorch will pack and unpack x into two new tensor objects that share the same storage with the original x (no copy performed).
Backward Hooks execution#
This section will discuss when different hooks fire or don’t fire. Then it will discuss the order in which they are fired. The hooks that will be covered are: backward hooks registered to Tensor via torch.Tensor.register_hook(), post-accumulate-grad hooks registered to Tensor via torch.Tensor.register_post_accumulate_grad_hook(), post-hooks registered to Node via torch.autograd.graph.Node.register_hook(), and pre-hooks registered to Node via torch.autograd.graph.Node.register_prehook().
Whether a particular hook will be fired#
Hooks registered to a Tensor via torch.Tensor.register_hook() are executed when gradients are being computed for that Tensor. (Note that this does not require the Tensor’s grad_fn to be executed. For example, if the Tensor is passed as part of the inputs argument to torch.autograd.grad(), the Tensor’s grad_fn may not be executed, but the hook register to that Tensor will always be executed.)
Hooks registered to a Tensor via torch.Tensor.register_post_accumulate_grad_hook() are executed after the gradients have been accumulated for that Tensor, meaning the Tensor’s grad field has been set. Whereas hooks registered via torch.Tensor.register_hook() are run as gradients are being computed, hooks registered via torch.Tensor.register_post_accumulate_grad_hook() are only triggered once the Tensor’s grad field is updated by autograd at the end of the backward pass. Thus, post-accumulate-grad hooks can only be registered for leaf Tensors. Registering a hook via torch.Tensor.register_post_accumulate_grad_hook() on a non-leaf Tensor will error, even if you call backward(retain_graph=True).
Hooks registered to torch.autograd.graph.Node using torch.autograd.graph.Node.register_hook() or torch.autograd.graph.Node.register_prehook() are only fired if the Node it was registered to is executed.
Whether a particular Node is executed may depend on whether the backward pass was called with torch.autograd.grad() or torch.autograd.backward(). Specifically, you should be aware of these differences when you register a hook on a Node corresponding to a Tensor that you are passing to torch.autograd.grad() or torch.autograd.backward() as part of the inputs argument.
If you are using torch.autograd.backward(), all of the above mentioned hooks will be executed, whether or not you specified the inputs argument. This is because .backward() executes all Nodes, even if they correspond to a Tensor specified as an input. (Note that the execution of this additional Node corresponding to Tensors passed as inputs is usually unnecessary, but done anyway. This behavior is subject to change; you should not depend on it.)
On the other hand, if you are using torch.autograd.grad(), the backward hooks registered to Nodes that correspond to the Tensors passed to input may not be executed, because those Nodes will not be executed unless there is another input that depends on the gradient result of this Node.
The order in which the different hooks are fired#
The order in which things happen are
hooks registered to Tensor are executed
pre-hooks registered to Node are executed (if Node is executed).
the
.gradfield is updated for Tensors that retain_gradNode is executed (subject to rules above)
for leaf Tensors that have
.gradaccumulated, post-accumulate-grad hooks are executedpost-hooks registered to Node are executed (if Node is executed)
If multiple hooks of the same type are registered on the same Tensor or Node they are executed in the order in which they are registered. Hooks that are executed later can observe the modifications to the gradient made by earlier hooks.
Special hooks#
torch.autograd.graph.register_multi_grad_hook() is implemented using hooks registered to Tensors. Each individual Tensor hook is fired following the Tensor hook ordering defined above and the registered multi-grad hook is called when the last Tensor gradient is computed.
torch.nn.modules.module.register_module_full_backward_hook() is implemented using hooks registered to Node. As the forward is computed, hooks are registered to grad_fn corresponding to the inputs and outputs of the module. Because a module may take multiple inputs and return multiple outputs, a dummy custom autograd Function is first applied to the inputs of the module before forward and the outputs of the module before the output of forward is returned to ensure that those Tensors share a single grad_fn, which we can then attach our hooks to.
Behavior of Tensor hooks when Tensor is modified in-place#
Usually hooks registered to a Tensor receive the gradient of the outputs with respect to that Tensor, where the value of the Tensor is taken to be its value at the time backward is computed.
However, if you register hooks to a Tensor, and then modify that Tensor in-place, hooks registered before in-place modification similarly receive gradients of the outputs with respect to the Tensor, but the value of the Tensor is taken to be its value before in-place modification.
If you prefer the behavior in the former case, you should register them to the Tensor after all in-place modifications to it have been made. For example
t = torch.tensor(1., requires_grad=True).sin()
t.cos_()
t.register_hook(fn)
t.backward()
Furthermore, it can be helpful to know that under the hood, when hooks are registered to a Tensor, they actually become permanently bound to the grad_fn of that Tensor, so if that Tensor is then modified in-place, even though the Tensor now has a new grad_fn, hooks registered before it was modified in-place will continue to be associated with the old grad_fn, e.g. they will fire when that Tensor’s old grad_fn is reached in the graph by the autograd engine.