評價此頁

torch.jit.trace#

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[source]#

跟蹤函式並返回一個可執行的或 ScriptFunction,它將使用即時編譯進行最佳化。

跟蹤最適合僅處理 Tensor 和包含 Tensor 的列表、字典和元組的程式碼。

使用 torch.jit.tracetorch.jit.trace_module,您可以將現有的模組或 Python 函式轉換為 TorchScript ScriptFunctionScriptModule。您必須提供示例輸入,我們會執行該函式,記錄所有張量上執行的操作。

  • 對獨立函式進行的由此產生的記錄會生成 ScriptFunction

  • nn.Module.forwardnn.Module 進行由此產生的記錄會生成 ScriptModule

此模組還包含原始模組的任何引數。

警告

跟蹤只能正確記錄不依賴於資料的函式和模組(例如,不包含張量中資料的條件判斷),並且不包含任何未跟蹤的外部依賴項(例如,執行輸入/輸出或訪問全域性變數)。跟蹤僅記錄給定函式在給定張量上執行時執行的操作。因此,返回的 ScriptModule 將始終在任何輸入上執行相同的跟蹤圖。當您的模組預期執行不同操作集時,這有一些重要的含義,具體取決於輸入和/或模組狀態。例如,

  • 跟蹤不會記錄任何控制流,例如 if 語句或迴圈。當該控制流在您的模組中是常量時,這是沒問題的,並且它通常會內聯控制流決策。但有時控制流實際上是模型本身的一部分。例如,迴圈神經網路是圍繞輸入序列(可能動態)長度的迴圈。

  • 在返回的 ScriptModule 中,在 trainingeval 模式下行為不同的操作將始終表現為在跟蹤期間處於的模式,無論 ScriptModule 處於哪種模式。

在這些情況下,跟蹤將不適用,而 指令碼化 是更好的選擇。如果您跟蹤這些模型,您可能會在後續呼叫模型時默默地獲得不正確的結果。跟蹤器會嘗試在執行可能導致生成錯誤跟蹤的操作時發出警告。

引數

func (callable or torch.nn.Module) – 將使用 example_inputs 執行的 Python 函式或 torch.nn.Modulefunc 的引數和返回值必須是張量,或包含張量的(可能巢狀的)元組。當將模組傳遞給 torch.jit.trace 時,僅執行和跟蹤 forward 方法(有關詳細資訊,請參閱 torch.jit.trace)。

關鍵字引數
  • example_inputs (tuple or torch.Tensor or None, optional) – 在跟蹤時將傳遞給函式的示例輸入元組。預設為 None。應指定此引數或 example_kwarg_inputs。生成的跟蹤可以與不同型別和形狀的輸入一起執行,前提是跟蹤的操作支援這些型別和形狀。 example_inputs 也可以是單個張量,在這種情況下它會自動包裝在元組中。當值為 None 時,應指定 example_kwarg_inputs

  • check_trace (bool, optional) – 檢查透過跟蹤程式碼執行的相同輸入是否產生相同的輸出。預設為 True。如果您需要停用此選項,例如,如果您的網路包含非確定性操作,或者如果您確信網路是正確的(儘管檢查器失敗)。

  • check_inputs (list of tuples, optional) – 用於將跟蹤與預期進行比較的一組輸入引數的元組列表。每個元組相當於在 example_inputs 中指定的輸入引數集。為了獲得最佳結果,請傳入一組代表網路預期輸入的形狀和型別空間的檢查輸入。如果未指定,則使用原始 example_inputs 進行檢查。

  • check_tolerance (float, optional) – 檢查器過程中使用的浮點數比較容差。在已知原因(例如,運算元融合)導致結果在數值上發生分歧的情況下,可以使用此選項來放寬檢查器的嚴格性。

  • strict (bool, optional) – 以嚴格模式執行跟蹤器或不執行(預設為 True)。僅當您希望跟蹤器記錄您的可變容器型別(當前是 list/dict)並且您確信您在問題中使用的容器是 constant 結構並且不被用作控制流(if, for)條件時,才將其關閉。

  • example_kwarg_inputs (dict, optional) – 此引數是在跟蹤時傳遞給函式的示例輸入的關鍵字引數包。預設為 None。應指定此引數或 example_inputs。字典將透過跟蹤函式的引數名稱進行解包。如果字典的鍵與跟蹤函式的引數名稱不匹配,將引發執行時異常。

返回

如果 funcnn.Modulenn.Moduleforward,則 trace 返回一個具有單個 forward 方法的 ScriptModule 物件,該方法包含跟蹤的程式碼。返回的 ScriptModule 將具有與原始 nn.Module 相同的子模組和引數集。如果 func 是獨立函式,則 trace 返回 ScriptFunction

示例(跟蹤函式)

import torch

def foo(x, y):
    return 2 * x + y

# Run `foo` with the provided inputs and record the tensor operations
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` can now be run with the TorchScript interpreter or saved
# and loaded in a Python-free environment

示例(跟蹤現有模組)

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)