torch.jit.save#
- torch.jit.save(m, f, _extra_files=None)[source]#
儲存該模組的離線版本,以供在單獨的程序中使用。
儲存的模組會序列化該模組的所有方法、子模組、引數和屬性。它可以使用 C++ API 透過
torch::jit::load(filename)載入,或者使用 Python API 中的torch.jit.load載入。要能夠儲存模組,它不能呼叫任何本地 Python 函式。這意味著所有子模組都必須繼承自
ScriptModule。危險
所有模組,無論其裝置如何,在載入時始終會載入到 CPU。這與
torch.load()的語義不同,並且將來可能會發生變化。- 引數
m – 要儲存的
ScriptModule。f – 一個類檔案物件(必須實現 write 和 flush 方法)或包含檔名的字串。
_extra_files – 檔名到內容的對映,這些內容將作為 f 的一部分儲存。
注意
torch.jit.save 嘗試在不同版本之間保留某些運算元的行為。例如,在 PyTorch 1.5 中,兩個整數張量相除執行的是整除,如果包含該程式碼的模組在 PyTorch 1.5 中儲存並在 PyTorch 1.6 中載入,其除法行為將得以保留。然而,在 PyTorch 1.6 中儲存的相同模組將無法在 PyTorch 1.5 中載入,因為 1.6 中除法的行為發生了變化,而 1.5 無法複製 1.6 的行為。
示例: .. testcode
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 m = torch.jit.script(MyModule()) # Save to file torch.jit.save(m, 'scriptmodule.pt') # This line is equivalent to the previous m.save("scriptmodule.pt") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.jit.save(m, buffer) # Save with extra files extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)