評價此頁

Gradcheck 機制#

建立日期: 2021年4月27日 | 最後更新日期: 2025年6月18日

本文件概述了 gradcheck()gradgradcheck() 函式的工作原理。

本文將涵蓋實值和復值函式的正向和反向模式自動微分,以及高階導數。本文件還將涵蓋 gradcheck 的預設行為以及傳遞 fast_mode=True 引數(下文稱為快速 gradcheck)的情況。

符號和背景資訊#

在本文件中,我們將使用以下約定

  1. xx, yy, aa, bb, vv, uu, ururuiui 是實值向量,而 zz 是一個復值向量,可以表示為兩個實值向量的形式:z=a+ibz = a + i b

  2. NNMM 分別是我們用於輸入和輸出空間的兩個整數。

  3. f:RNRMf: \mathcal{R}^N \to \mathcal{R}^M 是我們的基本實到實函式,其中 y=f(x)y = f(x)

  4. g:CNRMg: \mathcal{C}^N \to \mathcal{R}^M 是我們的基本復到實函式,其中 y=g(z)y = g(z)

對於簡單的實到實函式,我們將其雅可比矩陣表示為 JfJ_f,大小為 M×NM \times N。此矩陣包含所有偏導數,其中位置 (i,j)(i, j) 的項是 yixj\frac{\partial y_i}{\partial x_j}. 反向模式自動微分則計算,對於給定的向量 vv,大小為 MM 的量 vTJfv^T J_f. 另一方面,正向模式自動微分計算,對於給定的向量 uu,大小為 NN 的量 JfuJ_f u

對於包含複數值的函式,情況要複雜得多。此處僅提供概要,完整描述請參見 複數自動微分

為了滿足複數可微性(柯西-黎曼方程)的約束條件,對所有實值損失函式來說,這些約束條件都過於嚴格,因此我們採用了維爾丁格演算。在維爾丁格演算的基本設定中,鏈式法則需要同時訪問維爾丁格導數(下文稱為 WW)和共軛維爾丁格導數(下文稱為 CWCW)。WWCWCW 都需要傳播,因為通常情況下,儘管有名稱,一個並不是另一個的複共軛。

為了避免傳播這兩個值,對於反向模式自動微分,我們始終假定正在計算導數的函式要麼是實值函式,要麼是更大的實值函式的一部分。此假設意味著我們在反向傳播過程中計算的所有中間梯度也與實值函式相關聯。實際上,此假設在進行最佳化時並無限制,因為此類問題需要實值目標(因為複數沒有自然順序)。

在此假設下,使用 WWCWCW 的定義,我們可以證明 W=CWW = CW^* (我們在此使用 * 表示複共軛),因此只需要“反向傳播透過圖”其中一個值,而另一個可以輕鬆恢復。為了簡化內部計算,PyTorch 使用 2CW2 * CW 作為其反向傳播並返回值,當用戶請求梯度時。與實值情況類似,當輸出實際在 RM\mathcal{R}^M 時,反向模式自動微分不會計算 2CW2 * CW,而僅計算 vT(2CW)v^T (2 * CW),其中 vRMv \in \mathcal{R}^M 是給定的向量。

對於正向模式自動微分,我們採用類似的邏輯,在這種情況下,我們假設該函式是更大函式的一部分,該函式的輸入在 R\mathcal{R} 中。在此假設下,我們可以得出類似的結論,即每個中間結果都對應一個輸入在 R\mathcal{R} 中的函式,並且在這種情況下,使用 WWCWCW 的定義,我們可以證明對於中間函式 W=CWW = CW。為了確保正向和反向模式在單變數函式的基本情況下計算相同的量,正向模式也計算 2CW2 * CW。與實值情況類似,當輸入實際在 RN\mathcal{R}^N 中時,正向模式自動微分不計算 2CW2 * CW,而是僅計算 (2CW)u(2 * CW) u,其中 uRNu \in \mathcal{R}^N 是給定的向量。

預設反向模式 gradcheck 行為#

實到實函式#

為了測試函式 f:RNRM,xyf: \mathcal{R}^N \to \mathcal{R}^M, x \to y,我們透過兩種方式重構完整的雅可比矩陣 JfJ_f,大小為 M×NM \times N:一種是解析方法,另一種是數值方法。解析方法使用我們的反向模式自動微分,而數值方法使用有限差分。然後逐個元素地比較兩個重構的雅可比矩陣是否相等。

預設實輸入數值評估#

如果我們考慮一維函式(N=M=1N = M = 1)的基本情況,那麼我們可以使用 維基百科文章 中的基本有限差分公式。我們使用“中心差分”以獲得更好的數值特性。

yxf(x+eps)f(xeps)2eps\frac{\partial y}{\partial x} \approx \frac{f(x + eps) - f(x - eps)}{2 * eps}

該公式易於推廣到多輸出(M>1M \gt 1),其中 yx\frac{\partial y}{\partial x} 是大小為 M×1M \times 1 的列向量,例如 f(x+eps)f(x + eps)。在這種情況下,上述公式可以原樣重用,並且只需對使用者函式進行兩次評估(即 f(x+eps)f(x + eps)f(xeps)f(x - eps))就可以近似整個雅可比矩陣。

處理多輸入(N>1N \gt 1)的情況計算成本更高。在這種情況下,我們逐個迴圈遍歷所有輸入,並對 xx 的每個元素依次應用 epseps 擾動。這允許我們逐列重構 JfJ_f 矩陣。

預設實輸入解析評估#

對於解析評估,我們利用上面所述的事實,即反向模式自動微分計算 vTJfv^T J_f. 對於只有一個輸出的函式,我們直接使用 v=1v = 1 來透過一次反向傳播恢復整個雅可比矩陣。

對於有多個輸出的函式,我們採用一個 for 迴圈,該迴圈遍歷每個輸出,其中每個 vv 是一個對應於每個輸出的獨熱向量。這允許我們逐行重構 JfJ_f 矩陣。

復到實函式#

為了測試函式 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,我們重構包含 2CW2 * CW 的(復值)矩陣。

預設複數輸入數值評估#

首先考慮 N=M=1N = M = 1 的基本情況。我們從 這篇研究論文(第 3 章)得知:

CW:=yz=12(ya+iyb)CW := \frac{\partial y}{\partial z^*} = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})

請注意,在上述等式中,ya\frac{\partial y}{\partial a}yb\frac{\partial y}{\partial b}RR\mathcal{R} \to \mathcal{R} 導數。為了在數值上計算它們,我們使用上面為實到實情況描述的方法。這允許我們計算 CWCW 矩陣,然後將其乘以 22

請注意,截至撰寫本文時,程式碼以一種略顯迂迴的方式計算此值。

# Code from https://github.com/pytorch/pytorch/blob/58eb23378f2a376565a66ac32c93a316c45b6131/torch/autograd/gradcheck.py#L99-L105
# Notation changes in this code block:
# s here is y above
# x, y here are a, b above

ds_dx = compute_gradient(eps)
ds_dy = compute_gradient(eps * 1j)
# conjugate wirtinger derivative
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
# wirtinger derivative
w_d = 0.5 * (ds_dx - ds_dy * 1j)
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()

# Since grad_out is always 1, and W and CW are complex conjugate of each other, the last line ends up computing exactly `conj_w_d + w_d.conj() = conj_w_d + conj_w_d = 2 * conj_w_d`.

預設複數輸入解析評估#

由於反向模式自動微分已經精確計算了 CWCW 的兩倍,因此我們在這裡使用了與實到實情況相同的技巧,當有多個實輸出時,我們逐行重構矩陣。

具有複數輸出的函式#

在這種情況下,使用者提供的函式不遵循自動微分的假設,即我們計算反向模式自動微分的函式是實值的。這意味著直接在函式上使用自動微分沒有明確定義。為了解決這個問題,我們將測試函式 h:PNCMh: \mathcal{P}^N \to \mathcal{C}^M (其中 P\mathcal{P} 可以是 R\mathcal{R}C\mathcal{C})替換為兩個函式:hrhrhihi,使得

我們定義了以下函式:

其中 qPq \in \mathcal{P}。然後,我們將根據上述實到實或復到實的情況,對 hrhrhihi 進行基本的梯度檢驗,具體取決於 P\mathcal{P}

請注意,截至撰寫本文時,程式碼並未顯式建立這些函式,而是透過將 realrealimagimag 引數手動傳遞給不同的函式來實現鏈式法則。當 grad_out=1\text{grad\_out} = 1 時,我們考慮 hrhr。當 grad_out=1j\text{grad\_out} = 1j 時,我們考慮 hihi

快速反向模式梯度檢驗#

雖然上述梯度檢驗的表述非常有用,可以確保正確性和可除錯性,但它非常慢,因為它會重建完整的雅可比矩陣。本節介紹了一種執行梯度檢驗的更快方法,同時不影響其正確性。透過在檢測到錯誤時新增特殊邏輯,可以恢復可除錯性。在這種情況下,我們可以執行預設版本,該版本會重建完整的矩陣,以便向用戶提供完整的詳細資訊。

這裡的總體策略是找到一個標量量,該標量量可以透過數值和解析方法高效計算,並且能夠充分代表緩慢梯度檢驗計算的完整矩陣,從而確保它能夠捕獲雅可比矩陣中的任何差異。

實到實函式的快速梯度檢驗#

我們想在這裡計算的標量量是 vTJfuv^T J_f u,對於給定的隨機向量 vRMv \in \mathcal{R}^M 和隨機單位範數向量 uRNu \in \mathcal{R}^N

對於數值評估,我們可以高效地計算:

Jfuf(x+ueps)f(xueps)2eps.J_f u \approx \frac{f(x + u * eps) - f(x - u * eps)}{2 * eps}.

然後,我們計算此向量與 vv 的點積,以獲得感興趣的標量值。

對於解析版本,我們可以使用反向模式自動微分來直接計算 vTJfv^T J_f。然後,我們將其與 uu 進行點積以獲得期望值。

復到實函式的快速梯度檢驗#

與實到實情況類似,我們要對完整矩陣進行約簡。但是,2CW2 * CW 矩陣是複數值的,因此在這種情況下,我們將比較復標量。

由於在數值情況下高效計算存在一些限制,併為了儘量減少數值評估的次數,我們計算以下(儘管可能令人驚訝的)標量值:

s:=2vT(real(CW)ur+iimag(CW)ui)s := 2 * v^T (real(CW) ur + i * imag(CW) ui)

其中 vRMv \in \mathcal{R}^MurRNur \in \mathcal{R}^NuiRNui \in \mathcal{R}^N

快速複數輸入數值評估#

我們首先考慮如何透過數值方法計算 ss。為此,請記住我們考慮的是 g:CNRM,zyg: \mathcal{C}^N \to \mathcal{R}^M, z \to y,其中 z=a+ibz = a + i b,並且 CW=12(ya+iyb)CW = \frac{1}{2} * (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b}),我們將其重寫為:

s=2vT(real(CW)ur+iimag(CW)ui)=2vT(12yaur+i12ybui)=vT(yaur+iybui)=vT((yaur)+i(ybui))\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= 2 * v^T (\frac{1}{2} * \frac{\partial y}{\partial a} ur + i * \frac{1}{2} * \frac{\partial y}{\partial b} ui) \\ &= v^T (\frac{\partial y}{\partial a} ur + i * \frac{\partial y}{\partial b} ui) \\ &= v^T ((\frac{\partial y}{\partial a} ur) + i * (\frac{\partial y}{\partial b} ui)) \end{aligned}

在上面的公式中,我們可以看到 yaur\frac{\partial y}{\partial a} urybui\frac{\partial y}{\partial b} ui 可以像實到實情況的快速版本一樣進行評估。一旦計算出這些實值量,我們就可以重建右側的復向量,並與實值 vv 向量進行點積。

快速複數輸入解析評估#

在解析情況下,事情會更簡單,我們重寫公式為:

s=2vT(real(CW)ur+iimag(CW)ui)=vTreal(2CW)ur+ivTimag(2CW)ui)=real(vT(2CW))ur+iimag(vT(2CW))ui\begin{aligned} s &= 2 * v^T (real(CW) ur + i * imag(CW) ui) \\ &= v^T real(2 * CW) ur + i * v^T imag(2 * CW) ui) \\ &= real(v^T (2 * CW)) ur + i * imag(v^T (2 * CW)) ui \end{aligned}

因此,我們可以利用反向模式自動微分提供一種高效計算 vT(2CW)v^T (2 * CW) 的方法,然後將其實部與 urur 進行點積,虛部與 uiui 進行點積,最後重構出最終的復標量 ss

為什麼不使用複數 uu#

此時,您可能會想,為什麼我們不選擇一個複數 uu 並直接執行約簡 2vTCWu2 * v^T CW u'. 為了深入探討這一點,在本段中,我們將使用複數版本的 uu,記為 u=ur+iuiu' = ur' + i ui'. 使用這樣的複數 uu',問題在於進行數值評估時,我們需要計算

2CWu=(ya+iyb)(ur+iui)=yaur+iyaui+iyburybui\begin{aligned} 2*CW u' &= (\frac{\partial y}{\partial a} + i \frac{\partial y}{\partial b})(ur' + i ui') \\ &= \frac{\partial y}{\partial a} ur' + i \frac{\partial y}{\partial a} ui' + i \frac{\partial y}{\partial b} ur' - \frac{\partial y}{\partial b} ui' \end{aligned}

這就需要四次實到實有限差分評估(是上述方法的兩倍)。由於這種方法沒有更多的自由度(變數數量相同),並且我們試圖在此處實現最快的評估,因此我們使用了上述另一種表述。

具有複數輸出的函式的快速梯度檢驗#

與慢速情況一樣,我們考慮兩個實值函式,併為每個函式使用上述適當的規則。

二階梯度檢驗實現#

PyTorch 還提供了一個驗證二階梯度的實用程式。這裡的目標是確保反向實現的微分也是正確的,並且計算結果正確。

此功能透過考慮函式 F:x,vvTJfF: x, v \to v^T J_f and use the gradcheck defined above on this function. Note that vv in this case is just a random vector with the same type as f(x)f(x).

The fast version of gradgradcheck is implemented by using the fast version of gradcheck on that same function FF.