torch.jit.fork#
- torch.jit.fork(func, *args, **kwargs)[原始碼]#
建立一個非同步任務來執行 func,並返回一個指向該執行結果值的引用。
fork 會立即返回,因此 func 的返回值可能尚未計算。要強制完成任務並訪問返回值,請對 Future 呼叫 torch.jit.wait。fork 使用返回 T 的 func 呼叫時,型別為 torch.jit.Future[T]。fork 呼叫可以任意巢狀,並可使用位置引數和關鍵字引數呼叫。非同步執行僅在 TorchScript 中執行時發生。如果在純 Python 中執行,fork 不會並行執行。fork 在跟蹤時呼叫也不會並行執行,但 fork 和 wait 呼叫將被捕獲在匯出的 IR 圖中。
警告
fork 任務將非確定性地執行。我們建議僅為不修改其輸入、模組屬性或全域性狀態的純函式生成並行 fork 任務。
- 引數
func (callable 或 torch.nn.Module) – 將被呼叫的 Python 函式或 torch.nn.Module。如果在 TorchScript 中執行,它將非同步執行,否則不會。跟蹤的 fork 呼叫將被捕獲在 IR 中。
*args – 呼叫 func 時使用的引數。
**kwargs – 呼叫 func 時使用的引數。
- 返回
一個指向 func 執行的引用。值 T 只能透過 torch.jit.wait 強制完成 func 來訪問。
- 返回型別
torch.jit.Future[T]
示例(fork 一個自由函式)
import torch from torch import Tensor def foo(a: Tensor, b: int) -> Tensor: return a + b def bar(a): fut: torch.jit.Future[Tensor] = torch.jit.fork(foo, a, b=2) return torch.jit.wait(fut) script_bar = torch.jit.script(bar) input = torch.tensor(2) # only the scripted version executes asynchronously assert script_bar(input) == bar(input) # trace is not run asynchronously, but fork is captured in IR graph = torch.jit.trace(bar, (input,)).graph assert "fork" in str(graph)
示例(fork 一個模組方法)
import torch from torch import Tensor class AddMod(torch.nn.Module): def forward(self, a: Tensor, b: int): return a + b class Mod(torch.nn.Module): def __init__(self) -> None: super(self).__init__() self.mod = AddMod() def forward(self, input): fut = torch.jit.fork(self.mod, a, b=2) return torch.jit.wait(fut) input = torch.tensor(2) mod = Mod() assert mod(input) == torch.jit.script(mod).forward(input)