InplaceFunction#
- class torch.autograd.function.InplaceFunction(inplace=False)[source]#
該類僅出於向後相容原因而存在。對於任何新的用例,請使用
Function而不是這個。- static backward(ctx, *grad_outputs)[source]#
定義使用反向模式自動微分來區分操作的公式。
此函式應被所有子類重寫。(定義此函式等同於定義
vjp函式。)它必須接受一個上下文
ctx作為第一個引數,然後是forward()返回的任意數量的輸出(對於 forward 函式的非 tensor 輸出將傳入 None),並且它應該返回與forward()的輸入相同數量的 tensor。每個引數是相對於給定輸出的梯度,每個返回值應該是相對於相應輸入的梯度。如果輸入不是 Tensor 或是一個不需要梯度的 Tensor,你可以只為該輸入傳遞 None 作為梯度。上下文可用於檢索在 forward 傳播期間儲存的 tensor。它還有一個屬性
ctx.needs_input_grad,它是一個布林元組,表示每個輸入是否需要梯度。例如,如果forward()的第一個輸入需要計算相對於輸出的梯度,那麼backward()將有ctx.needs_input_grad[0] = True。- 返回型別
- static forward(*args, **kwargs)[source]#
定義自定義自動微分函式的前向傳播。
此函式應被所有子類覆蓋。定義 forward 有兩種方式
用法 1 (組合 forward 和 ctx)
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必須接受一個 ctx 作為第一個引數,後跟任意數量的引數(tensor 或其他型別)。
有關更多詳細資訊,請參閱 組合或分離 forward() 和 setup_context()。
用法 2 (分離 forward 和 ctx)
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 引數。
相反,您還必須覆蓋
torch.autograd.Function.setup_context()靜態方法來處理ctx物件的設定。output是 forward 的輸出,inputs是一個包含 forward 輸入的元組。有關更多詳細資訊,請參閱 擴充套件 torch.autograd。
上下文可用於儲存可以在 backward 傳播期間檢索的任意資料。不應直接將 tensor 儲存在 ctx 上(儘管出於向後相容性原因,目前不強制執行)。相反,tensor 應使用
ctx.save_for_backward()儲存(如果它們打算在backward中使用(等價於vjp))或使用ctx.save_for_forward()儲存(如果它們打算在jvp中使用)。- 返回型別
- static jvp(ctx, *grad_inputs)[source]#
定義使用前向模式自動微分來區分操作的公式。
此函式應被所有子類覆蓋。它必須接受一個上下文
ctx作為第一個引數,然後是forward()接收的任意數量的輸入(對於 forward 函式的非 tensor 輸入將傳入 None),並且它應該返回與forward()的輸出相同數量的 tensor。每個引數是相對於給定輸入的梯度,每個返回值應該是相對於相應輸出的梯度。如果輸出不是 Tensor 或函式相對於該輸出不可微分,你可以只為該輸入傳遞 None 作為梯度。You can use the
ctxobject to pass any value from the forward to this functions.- 返回型別
- mark_dirty(*args)[source]#
將給定張量標記為在就地操作中已修改。
應在
setup_context()或forward()方法中最多呼叫一次,所有引數都應該是輸入。在
forward()呼叫中以原地方式修改的每個 tensor 都應傳遞給此函式,以確保我們的檢查的正確性。函式是在修改之前還是之後呼叫的並不重要。- 示例:
>>> class Inplace(Function): >>> @staticmethod >>> def forward(ctx, x): >>> x_npy = x.numpy() # x_npy shares storage with x >>> x_npy += 1 >>> ctx.mark_dirty(x) >>> return x >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_output): >>> return grad_output >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone() >>> b = a * a >>> Inplace.apply(a) # This would lead to wrong gradients! >>> # but the engine would not know unless we mark_dirty >>> b.backward() # RuntimeError: one of the variables needed for gradient >>> # computation has been modified by an inplace operation
- mark_non_differentiable(*args)[source]#
將輸出標記為不可微分。
應在
setup_context()或forward()方法中最多呼叫一次,所有引數都應該是 tensor 輸出。這將把輸出標記為不需要梯度,從而提高反向傳播計算的效率。你仍然需要在
backward()中接受每個輸出的梯度,但它始終是一個與相應輸出形狀相同的零張量。- 此功能用於例如從排序返回的索引。請參閱示例:
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x): >>> sorted, idx = x.sort() >>> ctx.mark_non_differentiable(idx) >>> ctx.save_for_backward(x, idx) >>> return sorted, idx >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): # still need to accept g2 >>> x, idx = ctx.saved_tensors >>> grad_input = torch.zeros_like(x) >>> grad_input.index_add_(0, idx, g1) >>> return grad_input
- save_for_backward(*tensors)[source]#
為未來的
backward()呼叫儲存給定的張量。save_for_backward應在setup_context()或forward()方法中最多呼叫一次,並且只能使用 tensor。所有打算在 backward 傳播中使用但不是 forward 函式的輸入或輸出的 tensor 都應使用
save_for_backward儲存(而不是直接儲存在ctx上),以防止梯度不正確和記憶體洩漏,並啟用已儲存 tensor hook 的應用。請參閱torch.autograd.graph.saved_tensors_hooks。有關更多詳細資訊,請參閱 擴充套件 torch.autograd。注意,如果儲存了中間 tensor(不是 forward 函式的輸入或輸出的 tensor)用於 backward,您的自定義 Function 可能不支援 double backward。不支援 double backward 的自定義 Function 應該使用
@once_differentiable裝飾其backward()方法,以便執行 double backward 時會引發錯誤。如果您想支援 double backward,可以:在 backward 時根據輸入重新計算中間項,或將中間項作為自定義 Function 的輸出返回。有關更多詳細資訊,請參閱 double backward 教程。在
backward()中,可以透過saved_tensors屬性訪問已儲存的 tensor。在將它們返回給使用者之前,會進行檢查以確保它們沒有被用於任何修改其內容的原地操作。引數也可以是
None。這不會執行任何操作。有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd。
示例
>>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> w = x * z >>> out = x * y + y * z + w * y >>> ctx.save_for_backward(x, y, w, out) >>> ctx.z = z # z is not a tensor >>> return out >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, grad_out): >>> x, y, w, out = ctx.saved_tensors >>> z = ctx.z >>> gx = grad_out * (y + y * z) >>> gy = grad_out * (x + z + w) >>> gz = None >>> return gx, gy, gz >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> d = Func.apply(a, b, c)
- save_for_forward(*tensors)[source]#
Save given tensors for a future call to
jvp().save_for_forward應在setup_context()或forward()方法中最多呼叫一次,所有引數都應該是 tensor。在
jvp()中,可以透過saved_tensors屬性訪問已儲存的物件。引數也可以是
None。這不會執行任何操作。有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd。
示例
>>> class Func(torch.autograd.Function): >>> @staticmethod >>> def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): >>> ctx.save_for_backward(x, y) >>> ctx.save_for_forward(x, y) >>> ctx.z = z >>> return x * y * z >>> >>> @staticmethod >>> def jvp(ctx, x_t, y_t, _): >>> x, y = ctx.saved_tensors >>> z = ctx.z >>> return z * (y * x_t + x * y_t) >>> >>> @staticmethod >>> def vjp(ctx, grad_out): >>> x, y = ctx.saved_tensors >>> z = ctx.z >>> return z * grad_out * y, z * grad_out * x, None >>> >>> a = torch.tensor(1., requires_grad=True, dtype=torch.double) >>> t = torch.tensor(1., dtype=torch.double) >>> b = torch.tensor(2., requires_grad=True, dtype=torch.double) >>> c = 4 >>> >>> with fwAD.dual_level(): >>> a_dual = fwAD.make_dual(a, t) >>> d = Func.apply(a_dual, b, c)
- set_materialize_grads(value)[source]#
Set whether to materialize grad tensors. Default is
True.這應該只從
setup_context()或forward()方法中呼叫。如果為
True,則在呼叫backward()和jvp()方法之前,未定義的 grad tensor 將被擴充套件為全零 tensor。示例
>>> class SimpleFunc(Function): >>> @staticmethod >>> def forward(ctx, x): >>> return x.clone(), x.clone() >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): >>> return g1 + g2 # No check for None necessary >>> >>> # We modify SimpleFunc to handle non-materialized grad outputs >>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x): >>> ctx.set_materialize_grads(False) >>> ctx.save_for_backward(x) >>> return x.clone(), x.clone() >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): >>> x, = ctx.saved_tensors >>> grad_input = torch.zeros_like(x) >>> if g1 is not None: # We must check for None now >>> grad_input += g1 >>> if g2 is not None: >>> grad_input += g2 >>> return grad_input >>> >>> a = torch.tensor(1., requires_grad=True) >>> b, _ = Func.apply(a) # induces g2 to be undefined
- static setup_context(ctx, inputs, output)[source]#
定義 autograd.Function 的 forward 傳播有兩種方式。
Either
使用簽名
forward(ctx, *args, **kwargs)覆蓋 forward。不覆蓋setup_context。ctx 的設定用於 backward 在forward中完成。使用簽名
forward(*args, **kwargs)覆蓋 forward 並覆蓋setup_context。ctx 的設定用於 backward 在setup_context中完成(而不是在forward中)。
有關更多詳細資訊,請參閱
torch.autograd.Function.forward()和 擴充套件 torch.autograd。- 返回型別
- static vjp(ctx, *grad_outputs)[source]#
定義使用反向模式自動微分來區分操作的公式。
此函式應被所有子類重寫。(定義此函式等同於定義
vjp函式。)它必須接受一個上下文
ctx作為第一個引數,然後是forward()返回的任意數量的輸出(對於 forward 函式的非 tensor 輸出將傳入 None),並且它應該返回與forward()的輸入相同數量的 tensor。每個引數是相對於給定輸出的梯度,每個返回值應該是相對於相應輸入的梯度。如果輸入不是 Tensor 或是一個不需要梯度的 Tensor,你可以只為該輸入傳遞 None 作為梯度。上下文可用於檢索在 forward 傳播期間儲存的 tensor。它還有一個屬性
ctx.needs_input_grad,它是一個布林元組,表示每個輸入是否需要梯度。例如,如果forward()的第一個輸入需要計算相對於輸出的梯度,那麼backward()將有ctx.needs_input_grad[0] = True。- 返回型別
- static vmap(info, in_dims, *args)[source]#
定義此 autograd.Function 在
torch.vmap()下的行為。要使
torch.autograd.Function()支援torch.vmap(),您必須覆蓋此靜態方法,或者將generate_vmap_rule設定為True(您不能同時執行這兩項)。如果您選擇重寫此靜態方法:它必須接受
第一個引數是一個
info物件。info.batch_size指定了要 vmap 的維度的大小,而info.randomness是傳遞給torch.vmap()的隨機性選項。第二個引數是一個
in_dims元組。對於args中的每個 arg,in_dims有一個相應的Optional[int]。如果 arg 不是 Tensor 或 arg 不被 vmap,則為None,否則,它是一個指定 Tensor 的哪個維度被 vmap 的整數。*args,與forward()的 args 相同。
vmap 靜態方法的返回值是一個元組
(output, out_dims)。與in_dims類似,out_dims的結構應與output相同,並且每個輸出都包含一個out_dim,指定輸出是否具有 vmap 的維度以及在該維度中的索引。有關更多詳細資訊,請參閱 使用 autograd.Function 擴充套件 torch.func。