評價此頁

torch.utils.checkpoint#

創建於:2025 年 6 月 16 日 | 最後更新於:2025 年 6 月 16 日

注意

檢查點是透過在反向傳播期間為每個檢查點段重新執行前向傳遞段來實現的。這可能導致諸如 RNG 狀態之類的持久狀態比沒有檢查點時更超前。預設情況下,檢查點包含用於管理 RNG 狀態的邏輯,以便使用 RNG(例如透過 dropout)的檢查點傳遞與非檢查點傳遞相比具有確定性的輸出。 the logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of checkpointed operations. 如果不需要與非檢查點傳遞相比具有確定性的輸出,請將 `preserve_rng_state=False` 傳遞給 `checkpoint` 或 `checkpoint_sequential`,以省略在每個檢查點期間儲存和恢復 RNG 狀態。

儲存邏輯會為 CPU 和另一種裝置型別(透過 `_infer_device_type` 從除 CPU 張量之外的張量引數推斷裝置型別)儲存和恢復 RNG 狀態到 `run_fn`。如果存在多種裝置,則裝置狀態僅為一種裝置型別儲存,其餘裝置將被忽略。因此,如果任何被檢查點的函式涉及隨機性,這可能會導致梯度不正確。(請注意,如果檢測到的裝置中包含 CUDA 裝置,則會優先選擇 CUDA 裝置;否則,將選擇遇到的第一個裝置。)如果沒有 CPU 張量,將儲存和恢復預設裝置型別狀態(預設值為 `cuda`,可以透過 `DefaultDeviceType` 設定為其他裝置)。但是,該邏輯無法預測使用者是否會在 `run_fn` 內部將張量移動到新裝置。因此,如果在 `run_fn` 中移動張量到新裝置(“新”表示不屬於 [當前裝置 + 張量引數的裝置] 集合),則永遠無法保證與非檢查點傳遞相比具有確定性的輸出。

torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, early_stop=True, **kwargs)[source]#

檢查點模型或模型的一部分。

啟用檢查點是一種以計算換記憶體的技術。在反向傳播期間計算梯度時,檢查點區域的前向計算不儲存用於反向傳播的張量,而是在反向傳播期間重新計算它們。啟用檢查點可以應用於模型的任何部分。

目前提供兩種檢查點實現,由 `use_reentrant` 引數確定。建議使用 `use_reentrant=False`。請參閱下面的註釋以討論它們的區別。

警告

如果反向傳播期間的 `function` 呼叫與前向傳遞不同,例如由於全域性變數,則檢查點版本可能不相等,可能導致引發錯誤或導致靜默的錯誤梯度。

警告

`use_reentrant` 引數應顯式傳遞。在 2.9 版本中,如果未傳遞 `use_reentrant`,我們將引發異常。如果您正在使用 `use_reentrant=True` 變體,請參閱下面的註釋以瞭解重要的注意事項和潛在限制。

注意

檢查點的可重入變體 (`use_reentrant=True`) 和非可重入變體 (`use_reentrant=False`) 在以下方面有所不同:

  • 非可重入檢查點在重新計算完所有必需的中間啟用後立即停止重新計算。此功能預設啟用,但可以透過 `set_checkpoint_early_stop()` 停用。可重入檢查點在反向傳播期間始終完全重新計算 `function`。

  • 可重入變體在前向傳遞期間不記錄 autograd 圖,因為它在 `torch.no_grad()` 下執行前向傳遞。非可重入版本會記錄 autograd 圖,允許在檢查點區域內的圖上執行反向傳播。

  • 可重入檢查點僅支援 `torch.autograd.backward()` API 進行反向傳播,而不帶其 `inputs` 引數,而後非可重入版本支援所有方式的反向傳播。

  • 可重入變體至少需要一個輸入和輸出具有 `requires_grad=True`。如果此條件不滿足,則模型的檢查點部分將沒有梯度。非可重入版本沒有此要求。

  • 可重入版本不將巢狀結構(例如,自定義物件、列表、字典等)中的張量視為參與 autograd,而非可重入版本則將其視為參與。

  • 可重入檢查點不支援包含從計算圖中分離的張量的檢查點區域,而非常可重入版本支援。對於可重入變體,如果檢查點段包含使用 `detach()` 或 `torch.no_grad()` 分離的張量,則反向傳播將引發錯誤。這是因為 `checkpoint` 使所有輸出都需要梯度,當模型中定義了一個張量不應有梯度時,這會導致問題。為避免此問題,請在 `checkpoint` 函式外部分離張量。

引數
  • function – 描述模型或模型一部分前向傳遞中執行的內容。它還應該知道如何處理作為元組傳遞的輸入。例如,在 LSTM 中,如果使用者傳遞 `(activation, hidden)`,`function` 應該正確地將第一個輸入用作 `activation`,第二個輸入用作 `hidden`。

  • args – 包含 `function` 輸入的元組。

關鍵字引數
  • preserve_rng_state (bool, optional) – 在每個檢查點期間省略儲存和恢復 RNG 狀態。請注意,在 `torch.compile` 下,此標誌無效,我們始終保留 RNG 狀態。預設值: `True`。

  • use_reentrant (bool) – 指定是否使用需要可重入 autograd 的啟用檢查點變體。此引數應顯式傳遞。在 2.9 版本中,如果未傳遞 `use_reentrant`,我們將引發異常。如果 `use_reentrant=False`,`checkpoint` 將使用不需要可重入 autograd 的實現。這使得 `checkpoint` 支援其他功能,例如與 `torch.autograd.grad` 正常工作以及支援傳遞給檢查點函式的關鍵字引數。

  • context_fn (Callable, optional) – 返回兩個上下文管理器元組的可呼叫物件。函式及其重計算將在第一個和第二個上下文管理器下分別執行。此引數僅在 `use_reentrant=False` 時受支援。

  • determinism_check (str, optional) – 指定要執行的確定性檢查的字串。預設設定為 `"default"`,它將重計算張量的形狀、資料型別和裝置與儲存的張量進行比較。要關閉此檢查,請指定 `"none"`。目前這是唯一支援的兩個值。如果您希望看到更多確定性檢查,請開啟一個 issue。此引數僅在 `use_reentrant=False` 時受支援。如果 `use_reentrant=True`,則確定性檢查始終停用。

  • debug (bool, optional) – 如果為 `True`,錯誤訊息還將包括在原始前向計算和重計算期間執行的操作的跟蹤。此引數僅在 `use_reentrant=False` 時受支援。

  • early_stop (bool, optional) – 如果為 `True`,非可重入檢查點會盡快停止重計算,因為它已經計算了所有需要的張量。如果 `use_reentrant=True`,則忽略此引數。可以使用 `set_checkpoint_early_stop()` 上下文管理器全域性覆蓋。預設值: `True`。

返回

在 `*args` 上執行 `function` 的輸出。

torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)[source]#

檢查點順序模型以節省記憶體。

順序模型按順序(順序)執行一系列模組/函式。因此,我們可以將此類模型劃分為多個段並檢查每個段。除最後一個段外,所有段都不會儲存中間啟用。每個檢查點段的輸入將被儲存,以便在反向傳播期間重新執行該段。

警告

`use_reentrant` 引數應顯式傳遞。在 2.9 版本中,如果未傳遞 `use_reentrant`,我們將引發異常。如果您正在使用 `use_reentrant=True` 變體,請參閱 :func:`~torch.utils.checkpoint.checkpoint` 以瞭解此變體的​​重要注意事項和限制。建議使用 `use_reentrant=False`。

引數
  • functions – 一個 `torch.nn.Sequential` 或模組/函式列表(構成模型),按順序執行。

  • segments – 在模型中建立的塊的數量。

  • input – 輸入到 `functions` 的一個張量。

  • preserve_rng_state (bool, optional) – 在每個檢查點期間省略儲存和恢復 RNG 狀態。預設值: `True`。

  • use_reentrant (bool) – 指定是否使用需要可重入 autograd 的啟用檢查點變體。此引數應顯式傳遞。在 2.5 版本中,如果未傳遞 `use_reentrant`,我們將引發異常。如果 `use_reentrant=False`,`checkpoint` 將使用不需要可重入 autograd 的實現。這使得 `checkpoint` 支援其他功能,例如與 `torch.autograd.grad` 正常工作以及支援傳遞給檢查點函式的關鍵字引數。

返回

在 `*inputs` 上順序執行 `functions` 的輸出。

示例

>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)[source]#

上下文管理器,用於設定檢查點在執行時是否應列印額外的除錯資訊。有關更多資訊,請參閱 `checkpoint()` 的 `debug` 標誌。請注意,設定時,此上下文管理器會覆蓋傳遞給檢查點的 `debug` 值。要推遲到本地設定,請將 `None` 傳遞給此上下文。

引數

enabled (bool) – 檢查點是否應列印除錯資訊。預設為“None”。

class torch.utils.checkpoint.CheckpointPolicy(value)[source]#

用於指定反向傳播期間檢查點策略的列舉。

支援以下策略:

  • `{MUST,PREFER}_SAVE`: 操作的輸出將在前向傳遞期間儲存,並且不會在反向傳遞期間重新計算。

  • `{MUST,PREFER}_RECOMPUTE`: 操作的輸出不會在前向傳遞期間儲存,並且將在反向傳遞期間重新計算。

使用 `MUST_*` 而非 `PREFER_*` 來指示該策略不應被 `torch.compile` 等其他子系統覆蓋。

注意

一個始終返回 `PREFER_RECOMPUTE` 的策略函式等同於標準的檢查點。

一個始終返回 `PREFER_SAVE` 的策略函式不等於不使用檢查點。使用這樣的策略將儲存額外的張量,而不僅僅是那些實際上用於梯度計算的張量。

class torch.utils.checkpoint.SelectiveCheckpointContext(*, is_recompute)[source]#

在選擇性檢查點期間傳遞給策略函式的上下文。

此類用於在選擇性檢查點期間將相關元資料傳遞給策略函式。元資料包括策略函式當前呼叫是否在重計算期間。

示例

>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    print(ctx.is_recompute)
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )
torch.utils.checkpoint.create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False)[source]#

幫助程式,用於在啟用檢查點期間避免重新計算某些操作。

與 `torch.utils.checkpoint.checkpoint` 一起使用此功能,以控制在反向傳播期間重新計算哪些操作。

引數
  • policy_fn_or_list (Callable or List) –

    • 如果提供了策略函式,它應該接受一個 `SelectiveCheckpointContext`、`OpOverload`、op 的 args 和 kwargs,並返回一個 `CheckpointPolicy` 列舉值,指示該 op 的執行是否應被重新計算。

    • 如果提供了一個操作列表,它等同於一個策略函式,該函式為指定的操作返回 `CheckpointPolicy.MUST_SAVE`,為所有其他操作返回 `CheckpointPolicy.PREFER_RECOMPUTE`。

  • allow_cache_entry_mutation (bool, optional) – 預設情況下,如果以確保正確性為目的而突變任何由選擇性啟用檢查點快取的張量,則會引發錯誤。如果設定為 `True`,則停用此檢查。

返回

兩個上下文管理器的元組。

示例

>>> import functools
>>>
>>> x = torch.rand(10, 10, requires_grad=True)
>>> y = torch.rand(10, 10, requires_grad=True)
>>>
>>> ops_to_save = [
>>>    torch.ops.aten.mm.default,
>>> ]
>>>
>>> def policy_fn(ctx, op, *args, **kwargs):
>>>    if op in ops_to_save:
>>>        return CheckpointPolicy.MUST_SAVE
>>>    else:
>>>        return CheckpointPolicy.PREFER_RECOMPUTE
>>>
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
>>>
>>> # or equivalently
>>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
>>>
>>> def fn(x, y):
>>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
>>>
>>> out = torch.utils.checkpoint.checkpoint(
>>>     fn, x, y,
>>>     use_reentrant=False,
>>>     context_fn=context_fn,
>>> )