torch.jit.trace#
- torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[source]#
跟蹤函式並返回一個可執行的或
ScriptFunction,它將使用即時編譯進行最佳化。跟蹤最適合僅處理
Tensor和包含Tensor的列表、字典和元組的程式碼。使用 torch.jit.trace 和 torch.jit.trace_module,您可以將現有的模組或 Python 函式轉換為 TorchScript
ScriptFunction或ScriptModule。您必須提供示例輸入,我們會執行該函式,記錄所有張量上執行的操作。對獨立函式進行的由此產生的記錄會生成 ScriptFunction。
對 nn.Module.forward 或 nn.Module 進行由此產生的記錄會生成 ScriptModule。
此模組還包含原始模組的任何引數。
警告
跟蹤只能正確記錄不依賴於資料的函式和模組(例如,不包含張量中資料的條件判斷),並且不包含任何未跟蹤的外部依賴項(例如,執行輸入/輸出或訪問全域性變數)。跟蹤僅記錄給定函式在給定張量上執行時執行的操作。因此,返回的 ScriptModule 將始終在任何輸入上執行相同的跟蹤圖。當您的模組預期執行不同操作集時,這有一些重要的含義,具體取決於輸入和/或模組狀態。例如,
跟蹤不會記錄任何控制流,例如 if 語句或迴圈。當該控制流在您的模組中是常量時,這是沒問題的,並且它通常會內聯控制流決策。但有時控制流實際上是模型本身的一部分。例如,迴圈神經網路是圍繞輸入序列(可能動態)長度的迴圈。
在返回的
ScriptModule中,在training和eval模式下行為不同的操作將始終表現為在跟蹤期間處於的模式,無論 ScriptModule 處於哪種模式。
在這些情況下,跟蹤將不適用,而
指令碼化是更好的選擇。如果您跟蹤這些模型,您可能會在後續呼叫模型時默默地獲得不正確的結果。跟蹤器會嘗試在執行可能導致生成錯誤跟蹤的操作時發出警告。- 引數
func (callable or torch.nn.Module) – 將使用 example_inputs 執行的 Python 函式或 torch.nn.Module。 func 的引數和返回值必須是張量,或包含張量的(可能巢狀的)元組。當將模組傳遞給 torch.jit.trace 時,僅執行和跟蹤
forward方法(有關詳細資訊,請參閱torch.jit.trace)。- 關鍵字引數
example_inputs (tuple or torch.Tensor or None, optional) – 在跟蹤時將傳遞給函式的示例輸入元組。預設為
None。應指定此引數或example_kwarg_inputs。生成的跟蹤可以與不同型別和形狀的輸入一起執行,前提是跟蹤的操作支援這些型別和形狀。 example_inputs 也可以是單個張量,在這種情況下它會自動包裝在元組中。當值為 None 時,應指定example_kwarg_inputs。check_trace (
bool, optional) – 檢查透過跟蹤程式碼執行的相同輸入是否產生相同的輸出。預設為True。如果您需要停用此選項,例如,如果您的網路包含非確定性操作,或者如果您確信網路是正確的(儘管檢查器失敗)。check_inputs (list of tuples, optional) – 用於將跟蹤與預期進行比較的一組輸入引數的元組列表。每個元組相當於在
example_inputs中指定的輸入引數集。為了獲得最佳結果,請傳入一組代表網路預期輸入的形狀和型別空間的檢查輸入。如果未指定,則使用原始example_inputs進行檢查。check_tolerance (float, optional) – 檢查器過程中使用的浮點數比較容差。在已知原因(例如,運算元融合)導致結果在數值上發生分歧的情況下,可以使用此選項來放寬檢查器的嚴格性。
strict (
bool, optional) – 以嚴格模式執行跟蹤器或不執行(預設為True)。僅當您希望跟蹤器記錄您的可變容器型別(當前是list/dict)並且您確信您在問題中使用的容器是constant結構並且不被用作控制流(if, for)條件時,才將其關閉。example_kwarg_inputs (dict, optional) – 此引數是在跟蹤時傳遞給函式的示例輸入的關鍵字引數包。預設為
None。應指定此引數或example_inputs。字典將透過跟蹤函式的引數名稱進行解包。如果字典的鍵與跟蹤函式的引數名稱不匹配,將引發執行時異常。
- 返回
如果 func 是 nn.Module 或 nn.Module 的
forward,則 trace 返回一個具有單個forward方法的ScriptModule物件,該方法包含跟蹤的程式碼。返回的 ScriptModule 將具有與原始nn.Module相同的子模組和引數集。如果func是獨立函式,則trace返回 ScriptFunction。
示例(跟蹤函式)
import torch def foo(x, y): return 2 * x + y # Run `foo` with the provided inputs and record the tensor operations traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3))) # `traced_foo` can now be run with the TorchScript interpreter or saved # and loaded in a Python-free environment
示例(跟蹤現有模組)
import torch import torch.nn as nn class Net(nn.Module): def __init__(self) -> None: super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # Trace a specific method and construct `ScriptModule` with # a single `forward` method module = torch.jit.trace(n.forward, example_forward_input) # Trace a module (implicitly traces `forward`) and construct a # `ScriptModule` with a single `forward` method module = torch.jit.trace(n, example_forward_input)