torch.jit.freeze#
- torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[原始碼]#
凍結 ScriptModule,將子模組和屬性內聯為常量。
凍結一個
ScriptModule會克隆它,並嘗試將克隆模組的子模組、引數和屬性內聯為 TorchScript IR 圖中的常量。預設情況下,forward 方法將被保留,同時也會保留 preserved_attrs 中指定的屬性和方法。此外,在保留方法中修改的任何屬性都將被保留。凍結目前只接受處於 eval 模式的 ScriptModules。
凍結應用了通用最佳化,這將加速您的模型,無論是在何種機器上。為了使用特定伺服器設定進一步最佳化,請在凍結後執行 optimize_for_inference。
- 引數
mod (
ScriptModule) – 要凍結的模組preserved_attrs (Optional[List[str]]) – 除了 forward 方法之外,還要保留的屬性列表。在保留方法中修改的屬性也將被保留。
optimize_numerics (bool) – 如果為
True,則會執行一組不嚴格保留數值的最佳化。有關最佳化的詳細資訊,請參閱 torch.jit.run_frozen_optimizations。
- 返回
凍結的
ScriptModule。
示例 (凍結一個帶有 Parameter 的簡單模組)
def forward(self, input): output = self.weight.mm(input) output = self.linear(output) return output scripted_module = torch.jit.script(MyModule(2, 3).eval()) frozen_module = torch.jit.freeze(scripted_module) # parameters have been removed and inlined into the Graph as constants assert len(list(frozen_module.named_parameters())) == 0 # See the compiled graph as Python code print(frozen_module.code)
示例 (凍結一個帶有保留屬性的模組)
def forward(self, input): self.modified_tensor += 1 return input + self.modified_tensor scripted_module = torch.jit.script(MyModule2().eval()) frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"]) # we've manually preserved `version`, so it still exists on the frozen module and can be modified assert frozen_module.version == 1 frozen_module.version = 2 # `modified_tensor` is detected as being mutated in the forward, so freezing preserves # it to retain model semantics assert frozen_module(torch.tensor(1)) == torch.tensor(12) # now that we've run it once, the next result will be incremented by one assert frozen_module(torch.tensor(1)) == torch.tensor(13)
注意
也支援凍結子模組屬性:frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=[“submodule.version”])
注意
如果您不確定某個屬性為何未被內聯為常量,您可以在 frozen_module.forward.graph 上執行 dump_alias_db 來檢視凍結是否檢測到該屬性已被修改。
注意
由於凍結將權重轉換為常量並移除模組層次結構,因此 to 和其他用於操作裝置或 dtype 的 nn.Module 方法將不再起作用。作為一種變通方法,您可以在 torch.jit.load 中指定 map_location 來重新對映裝置,但是特定於裝置的邏輯可能已經嵌入到模型中。