評價此頁

torch.jit.script#

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)[source]#

編譯函式。

編譯一個函式或 nn.Module 會檢查原始碼,使用 TorchScript 編譯器將其編譯為 TorchScript 程式碼,並返回一個 ScriptModuleScriptFunction。TorchScript 本身是 Python 語言的一個子集,所以並非所有的 Python 功能都可用,但我們提供了足夠的功能來處理張量並進行控制流操作。完整的指南請參見 TorchScript 語言參考

編譯字典或列表會將其中資料複製到 TorchScript 例項中,之後可以在 Python 和 TorchScript 之間零複製地傳遞。

torch.jit.script 可以用作模組、函式、字典和列表的函式

也可以用作 torchscript 類和函式的裝飾器 @torch.jit.script

引數
  • obj (Callable, class, or nn.Module) – 要編譯的 nn.Module、函式、類型別、字典或列表。

  • example_inputs (Union[List[Tuple], Dict[Callable, List[Tuple]], None]) – 提供示例輸入來註解函式或 nn.Module 的引數。

返回

如果 objnn.Modulescript 返回一個 ScriptModule 物件。返回的 ScriptModule 將擁有與原始 nn.Module 相同的子模組和引數。如果 obj 是一個獨立的函式,則會返回一個 ScriptFunction。如果 obj 是一個 dict,則 script 返回 torch._C.ScriptDict 的一個例項。如果 obj 是一個 list,則 script 返回 torch._C.ScriptList 的一個例項。

編譯函式

使用 @torch.jit.script 裝飾器將透過編譯函式體來構造一個 ScriptFunction

示例(編譯函式)

import torch

@torch.jit.script
def foo(x, y):
    if x.max() > y.max():
        r = x
    else:
        r = y
    return r

print(type(foo))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(foo.code)

# Call the function using the TorchScript interpreter
foo(torch.ones(2, 2), torch.ones(2, 2))
**使用 example_inputs 編譯函式

示例輸入可用於註解函式引數。

示例(在編譯前註解函式)

import torch

def test_sum(a, b):
    return a + b

# Annotate the arguments to be int
scripted_fn = torch.jit.script(test_sum, example_inputs=[(3, 4)])

print(type(scripted_fn))  # torch.jit.ScriptFunction

# See the compiled graph as Python code
print(scripted_fn.code)

# Call the function using the TorchScript interpreter
scripted_fn(20, 100)
編譯 nn.Module

預設情況下,編譯 nn.Module 將編譯 forward 方法,並遞迴地編譯 forward 呼叫到的任何方法、子模組和函式。如果 nn.Module 只使用了 TorchScript 支援的特性,則無需修改原始模組程式碼。script 將構造一個 ScriptModule,其中包含原始模組的屬性、引數和方法的副本。

示例(編譯一個帶有 Parameter 的簡單模組)

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super().__init__()
        # This parameter will be copied to the new ScriptModule
        self.weight = torch.nn.Parameter(torch.rand(N, M))

        # When this submodule is used, it will be compiled
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mv(input)

        # This calls the `forward` method of the `nn.Linear` module, which will
        # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3))

示例(編譯一個帶有被跟蹤子模組的模組)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # torch.jit.trace produces a ScriptModule's conv1 and conv2
        self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
        self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        return input

scripted_module = torch.jit.script(MyModule())

要編譯 forward 以外的方法(並遞迴地編譯它呼叫的任何內容),請將 @torch.jit.export 裝飾器新增到該方法。要選擇退出編譯,請使用 @torch.jit.ignore@torch.jit.unused

示例(模組中被匯出和忽略的方法)

import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    @torch.jit.export
    def some_entry_point(self, input):
        return input + 10

    @torch.jit.ignore
    def python_only_fn(self, input):
        # This function won't be compiled, so any
        # Python APIs can be used
        import pdb

        pdb.set_trace()

    def forward(self, input):
        if self.training:
            self.python_only_fn(input)
        return input * 99


scripted_module = torch.jit.script(MyModule())
print(scripted_module.some_entry_point(torch.randn(2, 2)))
print(scripted_module(torch.randn(2, 2)))

示例(使用 example_inputs 註解 nn.Module 的 forward 方法)

import torch
import torch.nn as nn
from typing import NamedTuple

class MyModule(NamedTuple):
result: List[int]

class TestNNModule(torch.nn.Module):
    def forward(self, a) -> MyModule:
        result = MyModule(result=a)
        return result

pdt_model = TestNNModule()

# Runs the pdt_model in eager model with the inputs provided and annotates the arguments of forward
scripted_model = torch.jit.script(pdt_model, example_inputs={pdt_model: [([10, 20, ], ), ], })

# Run the scripted_model with actual inputs
print(scripted_model([20]))