fuse_modules#
- class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[原始碼]#
將一系列模組融合為單個模組。
僅融合以下模組序列:conv, bn conv, bn, relu conv, relu linear, relu bn, relu 所有其他序列保持不變。對於這些序列,將列表中的第一個模組替換為融合後的模組,其餘模組替換為identity。
- 引數
model – 包含待融合模組的模型
modules_to_fuse – 要融合的模組名稱列表。如果只有一個模組列表要融合,也可以是字串列表。
inplace – 指定融合是否在模型上原地進行,預設為返回一個新模型
fuser_func – 一個函式,接收一個模組列表並輸出一個相同長度的融合模組列表。例如,fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()] 預設為 torch.ao.quantization.fuse_known_modules
fuse_custom_config_dict – 融合的自定義配置
# Example of fuse_custom_config_dict fuse_custom_config_dict = { # Additional fuser_method mapping "additional_fuser_method_mapping": { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, }
- 返回
融合模組後的模型。如果 inplace=True,則會建立一個新副本。
示例
>>> m = M().eval() >>> # m is a module containing the sub-modules below >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input) >>> m = M().eval() >>> # Alternately provide a single list of modules to fuse >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input)