評價此頁

torch.jit.trace_module#

torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[source]#

跟蹤一個模組並返回一個透過即時編譯最佳化的可執行 ScriptModule

當一個模組被傳遞給 torch.jit.trace 時,只有 forward 方法會被執行和跟蹤。使用 trace_module,您可以指定一個方法名到示例輸入的字典來跟蹤(請參閱下面的 inputs 引數)。

有關跟蹤的更多資訊,請參閱 torch.jit.trace

引數
  • mod (torch.nn.Module) – 一個 torch.nn.Module,其中包含在 inputs 中指定的方法名。給定的方法將被編譯為單個 ScriptModule 的一部分。

  • inputs (dict) – 一個包含由 mod 中方法名索引的示例輸入的字典。在跟蹤時,輸入將被傳遞給名稱與輸入的鍵對應的那些方法。 { 'forward' : example_forward_input, 'method2': example_method2_input}

關鍵字引數
  • check_trace (bool, optional) – 檢查相同的輸入透過跟蹤程式碼是否產生相同的輸出。預設值:True。如果您希望停用此選項,例如,因為您的網路包含非確定性操作,或者如果您確信網路是正確的,儘管檢查器失敗。

  • check_inputs (list of dicts, optional) – 一個輸入引數字典列表,用於根據預期檢查跟蹤。每個元組等同於在 inputs 中指定的一組輸入引數。為了獲得最佳結果,請提供一組能夠代表網路預期輸入的形狀和型別的檢查輸入。如果未指定,則使用原始 inputs 進行檢查。

  • check_tolerance (float, optional) – 在檢查器過程中使用的浮點數比較容差。當由於已知原因(例如運算子融合)導致結果在數值上出現偏差時,可以使用此引數來放寬檢查器的嚴格性。

  • example_inputs_is_kwarg (bool, optional) – 此引數指示示例輸入是否為關鍵字引數的集合。預設值:False

返回

一個 ScriptModule 物件,其中包含一個 forward 方法,該方法包含跟蹤的程式碼。當 func 是一個 torch.nn.Module 時,返回的 ScriptModule 將具有與 func 相同的子模組和引數集。

示例(跟蹤具有多個方法的模組)

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)

# Trace specific methods on a module (specified in `inputs`), constructs
# a `ScriptModule` with `forward` and `weighted_kernel_sum` methods
inputs = {
    "forward": example_forward_input,
    "weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)