評價此頁

理解 requires_grad、retain_grad、葉子張量和非葉子張量#

作者: Justin Silver

本教程使用一個簡單的示例,解釋了 requires_gradretain_grad、葉子張量和非葉子張量的細微差別。

在開始之前,請確保您理解 張量及其操作方法。對 autograd 的工作原理 的基本瞭解也將很有用。

設定#

首先,請確保已 安裝 PyTorch,然後匯入必要的庫。

import torch
import torch.nn.functional as F

接下來,我們例項化一個簡單的網路來關注梯度。這將是一個仿射層,後跟 ReLU 啟用,最後是預測張量和標籤張量之間的 MSE 損失。

\[\mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})\]
\[L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})\]

請注意,引數(Wb)需要 requires_grad=True,以便 PyTorch 跟蹤涉及這些張量的操作。我們將在未來的 部分 中更詳細地討論這一點。

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b                           # pre-activation with shape: (1, 2)
y_pred = F.relu(z)                        # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)              # scalar loss

葉子張量與非葉子張量#

在執行正向傳播後,PyTorch autograd 已構建了一個 動態計算圖,如下所示。這是一個 有向無環圖 (DAG),它記錄了輸入張量(葉子節點)、對這些張量的所有後續操作以及中間/輸出張量(非葉子節點)。該圖使用微積分中的 鏈式法則,從圖的根(輸出)到葉子(輸入)計算每個張量的梯度。

\[\mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)\]
\[\frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot \cdots \cdot \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}\]
Computational graph after forward pass

正向傳播後的計算圖#

PyTorch 將一個節點視為葉子,如果它不是至少一個具有 requires_grad=True 的輸入張量運算的結果(例如 xWby),而所有其他節點則被視為非葉子(例如 zy_predloss)。您可以透過檢查張量的 is_leaf 屬性以程式設計方式驗證這一點。

# prints True because new tensors are leafs by convention
print(f"{x.is_leaf=}")

# prints False because tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf=}")
x.is_leaf=True
z.is_leaf=False

葉子和非葉子之間的區別決定了在反向傳播後,張量的梯度是否會儲存在其 grad 屬性中,從而可用於 梯度下降。我們將在 下一節 中更詳細地介紹這一點。

現在讓我們研究一下 PyTorch 如何在其計算圖中計算和儲存張量的梯度。

requires_grad#

為了構建可用於梯度計算的計算圖,我們需要將 requires_grad=True 引數傳遞給張量建構函式。預設情況下,值為 False,因此 PyTorch 不會跟蹤建立的任何張量的梯度。要驗證這一點,請嘗試不設定 requires_grad,重新執行正向傳播,然後執行反向傳播。您會看到

>>> loss.backward()
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

此錯誤意味著 autograd 無法反向傳播到任何葉子張量,因為 loss 未跟蹤梯度。如果您需要更改此屬性,可以透過在張量上呼叫 requires_grad_() 來實現(注意後面的 _)。

我們可以像上面使用 is_leaf 屬性一樣,對哪些節點需要梯度計算進行健全性檢查。

print(f"{x.requires_grad=}") # prints False because requires_grad=False by default
print(f"{W.requires_grad=}") # prints True because we set requires_grad=True in constructor
print(f"{z.requires_grad=}") # prints True because tensor is a non-leaf node
x.requires_grad=False
W.requires_grad=True
z.requires_grad=True

需要記住的是,非葉子張量根據定義具有 requires_grad=True,否則反向傳播將失敗。如果張量是葉子,那麼只有當用戶明確設定時,它才具有 requires_grad=True。換句話說,如果張量的至少一個輸入需要梯度,那麼它也會需要梯度。

此規則有兩個例外:

  1. 任何具有 nn.Parameternn.Module 的引數將具有 requires_grad=True(參見 此處)。

  2. 使用上下文管理器本地停用梯度計算(參見 此處)。

總之,requires_grad 告訴 autograd 需要為反向傳播計算哪些張量的梯度。這不同於哪些張量的 grad 欄位會被填充,這是下一節的主題。

retain_grad#

為了實際執行最佳化(例如 SGD、Adam 等),我們需要執行反向傳播以便提取梯度。

backward() 呼叫會填充所有具有 requires_grad=True 的葉子張量的 grad 欄位。 grad 是損失相對於我們正在探測的張量的梯度。在執行 backward() 之前,此屬性設定為 None

print(f"{W.grad=}")
print(f"{b.grad=}")
W.grad=tensor([[3., 3.],
        [3., 3.],
        [3., 3.]])
b.grad=tensor([[3., 3.]])

您可能想知道我們網路中的其他張量。讓我們檢查剩餘的葉子節點。

# prints all None because requires_grad=False
print(f"{x.grad=}")
print(f"{y.grad=}")
x.grad=None
y.grad=None

這些張量的梯度未被填充,因為我們沒有明確告訴 PyTorch 計算它們的梯度(requires_grad=False)。

現在讓我們看一箇中間非葉子節點。

print(f"{z.grad=}")
/var/lib/workspace/beginner_source/understanding_leaf_vs_nonleaf_tutorial.py:215: UserWarning:

The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more information. (Triggered internally at /pytorch/build/aten/src/ATen/core/TensorBody.h:489.)

z.grad=None

PyTorch 返回 None 作為梯度,並警告我們正在訪問非葉子節點的 grad 屬性。雖然 autograd 必須計算中間梯度才能使反向傳播正常工作,但它假定您之後不需要訪問這些值。要更改此行為,我們可以對張量使用 retain_grad() 函式。這會告訴 autograd 引擎在呼叫 backward() 後填充該張量的 grad

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# have to zero out gradients otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad=}")
print(f"{b.grad=}")
print(f"{z.grad=}")
print(f"{y_pred.grad=}")
print(f"{loss.grad=}")
W.grad=tensor([[3., 3.],
        [3., 3.],
        [3., 3.]])
b.grad=tensor([[3., 3.]])
z.grad=tensor([[3., 3.]])
y_pred.grad=tensor([[3., 3.]])
loss.grad=tensor(1.)

我們獲得的 W.grad 結果與之前相同。另外請注意,由於損失是標量,損失相對於其自身的梯度就是 1.0

如果我們檢視計算圖現在的狀態,我們會發現中間張量的 retains_grad 屬性已發生更改。按約定,此屬性將列印 False 對於任何葉子節點,即使它需要其梯度。

Computational graph after backward pass

反向傳播後的計算圖#

如果您對非葉子節點呼叫 retain_grad(),則不會產生任何效果。如果我們對具有 requires_grad=False 的節點呼叫 retain_grad(),PyTorch 實際上會丟擲錯誤,因為它無法儲存梯度(如果它從未被計算過)。

>>> x.retain_grad()
RuntimeError: can't retain_grad on Tensor that has requires_grad=False

摘要表#

使用 retain_grad()retains_grad 僅對非葉子節點有意義,因為對於具有 requires_grad=True 的葉子張量,grad 屬性已經填充。預設情況下,這些非葉子節點在反向傳播後不保留(儲存)其梯度。我們可以透過重新執行正向傳播,告訴 PyTorch 儲存梯度,然後執行反向傳播來更改此行為。

下表可用作參考,總結了上述討論。以下場景是 PyTorch 張量唯一有效的場景。

is_leaf

requires_grad

retains_grad

require_grad()

retain_grad()

requires_grad 設定為 TrueFalse

無操作

requires_grad 設定為 TrueFalse

無操作

無操作

retains_grad 設定為 True

無操作

無操作

結論#

在本教程中,我們涵蓋了 PyTorch 何時以及如何為葉子和非葉子張量計算梯度。透過使用 retain_grad,我們可以訪問 autograd 計算圖中中間張量的梯度。

如果您想了解更多關於 PyTorch 的 autograd 系統如何工作的資訊,請訪問下面的 參考資料。如果您對此教程有任何反饋(改進、拼寫錯誤修復等),請使用 PyTorch 論壇 和/或 issue 跟蹤器 聯絡我們。

參考資料#

指令碼總執行時間: (0 分鐘 0.321 秒)