評價此頁

torch.autograd.Function.backward#

static Function.backward(ctx, *grad_outputs)[source]#

定義使用反向模式自動微分來區分操作的公式。

此函式應被所有子類重寫。(定義此函式等同於定義 vjp 函式。)

它必須接受一個上下文 ctx 作為第一個引數,然後是 forward() 返回的輸出數量(對於 forward 函式的非張量輸出將傳入 None),並且它應該返回與 forward() 的輸入數量相等的張量。每個引數是關於給定輸出的梯度,並且每個返回值應該是關於相應輸入的梯度。如果輸入不是張量或是不需要梯度的張量,你可以只為該輸入傳入 None 作為梯度。

上下文可用於檢索在 forward 傳播期間儲存的張量。它還有一個屬性 ctx.needs_input_grad,它是一個布林值元組,表示每個輸入是否需要梯度。例如,如果 forward() 的第一個輸入需要計算關於輸出的梯度,那麼 backward() 將會有 ctx.needs_input_grad[0] = True

返回型別

任何