評價此頁

InplaceFunction#

class torch.autograd.function.InplaceFunction(inplace=False)[source]#

該類僅出於向後相容原因而存在。對於任何新的用例,請使用 Function 而不是這個。

static backward(ctx, *grad_outputs)[source]#

定義使用反向模式自動微分來區分操作的公式。

此函式應被所有子類重寫。(定義此函式等同於定義 vjp 函式。)

它必須接受一個上下文 ctx 作為第一個引數,然後是 forward() 返回的任意數量的輸出(對於 forward 函式的非 tensor 輸出將傳入 None),並且它應該返回與 forward() 的輸入相同數量的 tensor。每個引數是相對於給定輸出的梯度,每個返回值應該是相對於相應輸入的梯度。如果輸入不是 Tensor 或是一個不需要梯度的 Tensor,你可以只為該輸入傳遞 None 作為梯度。

上下文可用於檢索在 forward 傳播期間儲存的 tensor。它還有一個屬性 ctx.needs_input_grad,它是一個布林元組,表示每個輸入是否需要梯度。例如,如果 forward() 的第一個輸入需要計算相對於輸出的梯度,那麼 backward() 將有 ctx.needs_input_grad[0] = True

返回型別

任何

static forward(*args, **kwargs)[source]#

定義自定義自動微分函式的前向傳播。

此函式應被所有子類覆蓋。定義 forward 有兩種方式

用法 1 (組合 forward 和 ctx)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2 (分離 forward 和 ctx)

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass


@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 引數。

  • 相反,您還必須覆蓋 torch.autograd.Function.setup_context() 靜態方法來處理 ctx 物件的設定。output 是 forward 的輸出,inputs 是一個包含 forward 輸入的元組。

  • 有關更多詳細資訊,請參閱 擴充套件 torch.autograd

上下文可用於儲存可以在 backward 傳播期間檢索的任意資料。不應直接將 tensor 儲存在 ctx 上(儘管出於向後相容性原因,目前不強制執行)。相反,tensor 應使用 ctx.save_for_backward() 儲存(如果它們打算在 backward 中使用(等價於 vjp))或使用 ctx.save_for_forward() 儲存(如果它們打算在 jvp 中使用)。

返回型別

任何

static jvp(ctx, *grad_inputs)[source]#

定義使用前向模式自動微分來區分操作的公式。

此函式應被所有子類覆蓋。它必須接受一個上下文 ctx 作為第一個引數,然後是 forward() 接收的任意數量的輸入(對於 forward 函式的非 tensor 輸入將傳入 None),並且它應該返回與 forward() 的輸出相同數量的 tensor。每個引數是相對於給定輸入的梯度,每個返回值應該是相對於相應輸出的梯度。如果輸出不是 Tensor 或函式相對於該輸出不可微分,你可以只為該輸入傳遞 None 作為梯度。

You can use the ctx object to pass any value from the forward to this functions.

返回型別

任何

mark_dirty(*args)[source]#

將給定張量標記為在就地操作中已修改。

應在 setup_context()forward() 方法中最多呼叫一次,所有引數都應該是輸入。

forward() 呼叫中以原地方式修改的每個 tensor 都應傳遞給此函式,以確保我們的檢查的正確性。函式是在修改之前還是之後呼叫的並不重要。

示例:
>>> class Inplace(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         x_npy = x.numpy() # x_npy shares storage with x
>>>         x_npy += 1
>>>         ctx.mark_dirty(x)
>>>         return x
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_output):
>>>         return grad_output
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
>>> b = a * a
>>> Inplace.apply(a)  # This would lead to wrong gradients!
>>>                   # but the engine would not know unless we mark_dirty
>>> b.backward() # RuntimeError: one of the variables needed for gradient
>>>              # computation has been modified by an inplace operation
mark_non_differentiable(*args)[source]#

將輸出標記為不可微分。

應在 setup_context()forward() 方法中最多呼叫一次,所有引數都應該是 tensor 輸出。

這將把輸出標記為不需要梯度,從而提高反向傳播計算的效率。你仍然需要在 backward() 中接受每個輸出的梯度,但它始終是一個與相應輸出形狀相同的零張量。

此功能用於例如從排序返回的索引。請參閱示例:
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # still need to accept g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
save_for_backward(*tensors)[source]#

為未來的 backward() 呼叫儲存給定的張量。

save_for_backward 應在 setup_context()forward() 方法中最多呼叫一次,並且只能使用 tensor。

所有打算在 backward 傳播中使用但不是 forward 函式的輸入或輸出的 tensor 都應使用 save_for_backward 儲存(而不是直接儲存在 ctx 上),以防止梯度不正確和記憶體洩漏,並啟用已儲存 tensor hook 的應用。請參閱 torch.autograd.graph.saved_tensors_hooks。有關更多詳細資訊,請參閱 擴充套件 torch.autograd

注意,如果儲存了中間 tensor(不是 forward 函式的輸入或輸出的 tensor)用於 backward,您的自定義 Function 可能不支援 double backward。不支援 double backward 的自定義 Function 應該使用 @once_differentiable 裝飾其 backward() 方法,以便執行 double backward 時會引發錯誤。如果您想支援 double backward,可以:在 backward 時根據輸入重新計算中間項,或將中間項作為自定義 Function 的輸出返回。有關更多詳細資訊,請參閱 double backward 教程

backward() 中,可以透過 saved_tensors 屬性訪問已儲存的 tensor。在將它們返回給使用者之前,會進行檢查以確保它們沒有被用於任何修改其內容的原地操作。

引數也可以是 None。這不會執行任何操作。

有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd

示例

>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * z
>>>         out = x * y + y * z + w * y
>>>         ctx.save_for_backward(x, y, w, out)
>>>         ctx.z = z  # z is not a tensor
>>>         return out
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, grad_out):
>>>         x, y, w, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + w)
>>>         gz = None
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
save_for_forward(*tensors)[source]#

Save given tensors for a future call to jvp().

save_for_forward 應在 setup_context()forward() 方法中最多呼叫一次,所有引數都應該是 tensor。

jvp() 中,可以透過 saved_tensors 屬性訪問已儲存的物件。

引數也可以是 None。這不會執行任何操作。

有關如何使用此方法的更多詳細資訊,請參閱 擴充套件 torch.autograd

示例

>>> class Func(torch.autograd.Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         ctx.save_for_backward(x, y)
>>>         ctx.save_for_forward(x, y)
>>>         ctx.z = z
>>>         return x * y * z
>>>
>>>     @staticmethod
>>>     def jvp(ctx, x_t, y_t, _):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * (y * x_t + x * y_t)
>>>
>>>     @staticmethod
>>>     def vjp(ctx, grad_out):
>>>         x, y = ctx.saved_tensors
>>>         z = ctx.z
>>>         return z * grad_out * y, z * grad_out * x, None
>>>
>>>     a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>>     t = torch.tensor(1., dtype=torch.double)
>>>     b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>>     c = 4
>>>
>>>     with fwAD.dual_level():
>>>         a_dual = fwAD.make_dual(a, t)
>>>         d = Func.apply(a_dual, b, c)
set_materialize_grads(value)[source]#

Set whether to materialize grad tensors. Default is True.

這應該只從 setup_context()forward() 方法中呼叫。

如果為 True,則在呼叫 backward()jvp() 方法之前,未定義的 grad tensor 將被擴充套件為全零 tensor。

示例

>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must check for None now
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
static setup_context(ctx, inputs, output)[source]#

定義 autograd.Function 的 forward 傳播有兩種方式。

Either

  1. 使用簽名 forward(ctx, *args, **kwargs) 覆蓋 forward。不覆蓋 setup_context。ctx 的設定用於 backward 在 forward 中完成。

  2. 使用簽名 forward(*args, **kwargs) 覆蓋 forward 並覆蓋 setup_context。ctx 的設定用於 backward 在 setup_context 中完成(而不是在 forward 中)。

有關更多詳細資訊,請參閱 torch.autograd.Function.forward()擴充套件 torch.autograd

返回型別

任何

static vjp(ctx, *grad_outputs)[source]#

定義使用反向模式自動微分來區分操作的公式。

此函式應被所有子類重寫。(定義此函式等同於定義 vjp 函式。)

它必須接受一個上下文 ctx 作為第一個引數,然後是 forward() 返回的任意數量的輸出(對於 forward 函式的非 tensor 輸出將傳入 None),並且它應該返回與 forward() 的輸入相同數量的 tensor。每個引數是相對於給定輸出的梯度,每個返回值應該是相對於相應輸入的梯度。如果輸入不是 Tensor 或是一個不需要梯度的 Tensor,你可以只為該輸入傳遞 None 作為梯度。

上下文可用於檢索在 forward 傳播期間儲存的 tensor。它還有一個屬性 ctx.needs_input_grad,它是一個布林元組,表示每個輸入是否需要梯度。例如,如果 forward() 的第一個輸入需要計算相對於輸出的梯度,那麼 backward() 將有 ctx.needs_input_grad[0] = True

返回型別

任何

static vmap(info, in_dims, *args)[source]#

定義此 autograd.Function 在 torch.vmap() 下的行為。

要使 torch.autograd.Function() 支援 torch.vmap(),您必須覆蓋此靜態方法,或者將 generate_vmap_rule 設定為 True(您不能同時執行這兩項)。

如果您選擇重寫此靜態方法:它必須接受

  • 第一個引數是一個 info 物件。info.batch_size 指定了要 vmap 的維度的大小,而 info.randomness 是傳遞給 torch.vmap() 的隨機性選項。

  • 第二個引數是一個 in_dims 元組。對於 args 中的每個 arg,in_dims 有一個相應的 Optional[int]。如果 arg 不是 Tensor 或 arg 不被 vmap,則為 None,否則,它是一個指定 Tensor 的哪個維度被 vmap 的整數。

  • *args,與 forward() 的 args 相同。

vmap 靜態方法的返回值是一個元組 (output, out_dims)。與 in_dims 類似,out_dims 的結構應與 output 相同,並且每個輸出都包含一個 out_dim,指定輸出是否具有 vmap 的維度以及在該維度中的索引。

有關更多詳細資訊,請參閱 使用 autograd.Function 擴充套件 torch.func