torch.autograd.Function.forward#
- static Function.forward(*args, **kwargs)[source]#
定義自定義自動微分函式的前向傳播。
此函式應被所有子類覆蓋。定義 forward 的兩種方法:
用法 1 (合併 forward 和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必須接受一個 context ctx 作為第一個引數,後跟任意數量的引數(張量或其他型別)。
有關更多詳細資訊,請參閱 合併或分開 forward() 和 setup_context()。
用法 2 (分開 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 引數。
相反,您還必須覆蓋
torch.autograd.Function.setup_context()靜態方法來處理ctx物件的設定。output是 forward 的輸出,inputs是 forward 輸入的元組。有關更多詳細資訊,請參閱 擴充套件 torch.autograd。
上下文可用於儲存可以在反向傳播期間檢索的任意資料。不應直接在 ctx 上儲存張量(儘管出於向後相容性目前不強制執行此操作)。而是應使用
ctx.save_for_backward()儲存張量(如果打算在backward(等同於vjp)中使用),或者使用ctx.save_for_forward()儲存張量(如果打算在jvp中使用)。- 返回型別