torch.jit.trace_module#
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[source]#
跟蹤一個模組並返回一個透過即時編譯最佳化的可執行
ScriptModule。當一個模組被傳遞給
torch.jit.trace時,只有forward方法會被執行和跟蹤。使用trace_module,您可以指定一個方法名到示例輸入的字典來跟蹤(請參閱下面的inputs引數)。有關跟蹤的更多資訊,請參閱
torch.jit.trace。- 引數
mod (torch.nn.Module) – 一個
torch.nn.Module,其中包含在inputs中指定的方法名。給定的方法將被編譯為單個 ScriptModule 的一部分。inputs (dict) – 一個包含由
mod中方法名索引的示例輸入的字典。在跟蹤時,輸入將被傳遞給名稱與輸入的鍵對應的那些方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
- 關鍵字引數
check_trace (
bool, optional) – 檢查相同的輸入透過跟蹤程式碼是否產生相同的輸出。預設值:True。如果您希望停用此選項,例如,因為您的網路包含非確定性操作,或者如果您確信網路是正確的,儘管檢查器失敗。check_inputs (list of dicts, optional) – 一個輸入引數字典列表,用於根據預期檢查跟蹤。每個元組等同於在
inputs中指定的一組輸入引數。為了獲得最佳結果,請提供一組能夠代表網路預期輸入的形狀和型別的檢查輸入。如果未指定,則使用原始inputs進行檢查。check_tolerance (float, optional) – 在檢查器過程中使用的浮點數比較容差。當由於已知原因(例如運算子融合)導致結果在數值上出現偏差時,可以使用此引數來放寬檢查器的嚴格性。
example_inputs_is_kwarg (
bool, optional) – 此引數指示示例輸入是否為關鍵字引數的集合。預設值:False。
- 返回
一個
ScriptModule物件,其中包含一個forward方法,該方法包含跟蹤的程式碼。當func是一個torch.nn.Module時,返回的ScriptModule將具有與func相同的子模組和引數集。
示例(跟蹤具有多個方法的模組)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input) # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods inputs = { "forward": example_forward_input, "weighted_kernel_sum": example_weight, } module = torch.jit.trace_module(n, inputs)