torch.jit.script#
- torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[source]#
編譯函式。
編譯一個函式或
nn.Module會檢查原始碼,使用 TorchScript 編譯器將其編譯為 TorchScript 程式碼,並返回一個ScriptModule或ScriptFunction。TorchScript 本身是 Python 語言的一個子集,所以並非所有的 Python 功能都可用,但我們提供了足夠的功能來處理張量並進行控制流操作。完整的指南請參見 TorchScript 語言參考。編譯字典或列表會將其中資料複製到 TorchScript 例項中,之後可以在 Python 和 TorchScript 之間零複製地傳遞。
torch.jit.script可以用作模組、函式、字典和列表的函式也可以用作 torchscript 類和函式的裝飾器
@torch.jit.script。
- 引數
obj (Callable, class, or nn.Module) – 要編譯的
nn.Module、函式、類型別、字典或列表。example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例輸入來註解函式或
nn.Module的引數。
- 返回
如果
obj是nn.Module,script返回一個ScriptModule物件。返回的ScriptModule將擁有與原始nn.Module相同的子模組和引數。如果obj是一個獨立的函式,則會返回一個ScriptFunction。如果obj是一個dict,則script返回 torch._C.ScriptDict 的一個例項。如果obj是一個list,則script返回 torch._C.ScriptList 的一個例項。
- 編譯函式
使用
@torch.jit.script裝飾器將透過編譯函式體來構造一個ScriptFunction。示例(編譯函式)
import torch @torch.jit.script def foo(x, y): if x.max() > y.max(): r = x else: r = y return r print(type(foo)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(foo.code) # Call the function using the TorchScript interpreter foo(torch.ones(2, 2), torch.ones(2, 2))
- **使用 example_inputs 編譯函式
示例輸入可用於註解函式引數。
示例(在編譯前註解函式)
import torch def test_sum(a, b): return a + b # Annotate the arguments to be int scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)]) print(type(scripted_fn)) # torch.jit.ScriptFunction # See the compiled graph as Python code print(scripted_fn.code) # Call the function using the TorchScript interpreter scripted_fn(20, 100)
- 編譯 nn.Module
預設情況下,編譯
nn.Module將編譯forward方法,並遞迴地編譯forward呼叫到的任何方法、子模組和函式。如果nn.Module只使用了 TorchScript 支援的特性,則無需修改原始模組程式碼。script將構造一個ScriptModule,其中包含原始模組的屬性、引數和方法的副本。示例(編譯一個帶有 Parameter 的簡單模組)
import torch class MyModule(torch.nn.Module): def __init__(self, N, M): super().__init__() # This parameter will be copied to the new ScriptModule self.weight = torch.nn.Parameter(torch.rand(N, M)) # When this submodule is used, it will be compiled self.linear = torch.nn.Linear(N, M) def forward(self, input): output = self.weight.mv(input) # This calls the `forward` method of the `nn.Linear` module, which will # cause the `self.linear` submodule to be compiled to a `ScriptModule` here output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3))
示例(編譯一個帶有被跟蹤子模組的模組)
import torch import torch.nn as nn import torch.nn.functional as F class MyModule(nn.Module): def __init__(self) -> None: super().__init__() # torch.jit.trace produces a ScriptModule's conv1 and conv2 self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16)) self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16)) def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) return input scripted_module = torch.jit.script(MyModule())
要編譯
forward以外的方法(並遞迴地編譯它呼叫的任何內容),請將@torch.jit.export裝飾器新增到該方法。要選擇退出編譯,請使用@torch.jit.ignore或@torch.jit.unused。示例(模組中被匯出和忽略的方法)
import torch import torch.nn as nn class MyModule(nn.Module): def __init__(self) -> None: super().__init__() @torch.jit.export def some_entry_point(self, input): return input + 10 @torch.jit.ignore def python_only_fn(self, input): # This function won't be compiled, so any # Python APIs can be used import pdb pdb.set_trace() def forward(self, input): if self.training: self.python_only_fn(input) return input * 99 scripted_module = torch.jit.script(MyModule()) print(scripted_module.some_entry_point(torch.randn(2, 2))) print(scripted_module(torch.randn(2, 2)))
示例(使用 example_inputs 註解 nn.Module 的 forward 方法)
import torch import torch.nn as nn from typing import NamedTuple class MyModule(NamedTuple): result: List[int] class TestNNModule(torch.nn.Module): def forward(self, a) -> MyModule: result = MyModule(result=a) return result pdt_model = TestNNModule() # Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], }) # Run the scripted_model with actual inputs print(scripted_model([20]))