評價此頁

torch.autograd.Function.forward#

static Function.forward(*args, **kwargs)[source]#

定義自定義自動微分函式的前向傳播。

此函式應被所有子類覆蓋。定義 forward 的兩種方法:

用法 1 (合併 forward 和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2 (分開 forward 和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 引數。

  • 相反,您還必須覆蓋 torch.autograd.Function.setup_context() 靜態方法來處理 ctx 物件的設定。 output 是 forward 的輸出,inputs 是 forward 輸入的元組。

  • 有關更多詳細資訊,請參閱 擴充套件 torch.autograd

上下文可用於儲存可以在反向傳播期間檢索的任意資料。不應直接在 ctx 上儲存張量(儘管出於向後相容性目前不強制執行此操作)。而是應使用 ctx.save_for_backward() 儲存張量(如果打算在 backward(等同於 vjp)中使用),或者使用 ctx.save_for_forward() 儲存張量(如果打算在 jvp 中使用)。

返回型別

任何