torch.autograd.Function.backward#
- static Function.backward(ctx, *grad_outputs)[source]#
定義使用反向模式自動微分來區分操作的公式。
此函式應被所有子類重寫。(定義此函式等同於定義
vjp函式。)它必須接受一個上下文
ctx作為第一個引數,然後是forward()返回的輸出數量(對於 forward 函式的非張量輸出將傳入 None),並且它應該返回與forward()的輸入數量相等的張量。每個引數是關於給定輸出的梯度,並且每個返回值應該是關於相應輸入的梯度。如果輸入不是張量或是不需要梯度的張量,你可以只為該輸入傳入 None 作為梯度。上下文可用於檢索在 forward 傳播期間儲存的張量。它還有一個屬性
ctx.needs_input_grad,它是一個布林值元組,表示每個輸入是否需要梯度。例如,如果forward()的第一個輸入需要計算關於輸出的梯度,那麼backward()將會有ctx.needs_input_grad[0] = True。- 返回型別