快捷方式

create_feature_extractor

torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[list[str], dict[str, str]]] = None, train_return_nodes: Optional[Union[list[str], dict[str, str]]] = None, eval_return_nodes: Optional[Union[list[str], dict[str, str]]] = None, tracer_kwargs: Optional[dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[dict[str, Any]] = None) GraphModule[原始碼]

建立一個新的圖模組,該模組從給定的模型返回中間節點,作為一個字典,其中使用者指定的鍵是字串,值是要提取的節點。這是透過 FX 重寫模型計算圖來實現的,以便將所需的節點作為輸出返回。所有未使用的節點及其對應的引數都將被刪除。

期望的輸出節點必須指定為一個以 . 分隔的路徑,該路徑從頂層模組向下遍歷到葉操作或葉模組。有關此處使用的節點命名約定的更多詳細資訊,請參閱 相關小節,位於 文件 中。

並非所有模型都可以透過 FX 進行跟蹤,儘管稍加調整可以使其相容。以下是一些(非詳盡的)提示:

  • 如果您不需要跟蹤某個特定的、有問題的子模組,可以透過將 leaf_modules 列表作為 tracer_kwargs 的一個引數來將其變成“葉模組”(參見下面的示例)。它不會被跟蹤,而是生成的圖將引用該模組的 forward 方法。

  • 同樣,您可以透過將 autowrap_functions 列表作為 tracer_kwargs 的一個引數來將函式變成葉函式(參見下面的示例)。

  • 一些內建的 Python 函式可能會有問題。例如,int 在跟蹤時會引發錯誤。您可以將其包裝在自己的函式中,然後將其作為 autowrap_functions 的引數傳遞給 tracer_kwargs

有關 FX 的更多資訊,請參閱 torch.fx 文件

引數:
  • model (nn.Module) – 將要從中提取特徵的模型

  • return_nodes (listdict, 可選) – ListDict,包含要返回其啟用值的節點名稱(或部分名稱 - 參見上方註釋)。如果為 Dict,則鍵是節點名稱,值是圖模組返回字典的使用者指定鍵。如果為 List,則將其視為直接對映節點規範字符串到輸出名稱的 Dict。如果同時指定了 train_return_nodeseval_return_nodes,則不應指定此引數。

  • train_return_nodes (listdict, 可選) – 類似於 return_nodes。如果訓練模式和評估模式的返回節點不同,可以使用此引數。如果指定了此引數,則必須同時指定 eval_return_nodes,並且不應指定 return_nodes

  • eval_return_nodes (listdict, 可選) – 類似於 return_nodes。如果訓練模式和評估模式的返回節點不同,可以使用此引數。如果指定了此引數,則必須同時指定 train_return_nodes,並且不應指定 return_nodes

  • tracer_kwargs (dict, 可選) – 傳遞給 NodePathTracer(它會將引數傳遞給其父類 torch.fx.Tracer)的關鍵字引數字典。預設情況下,它將包裝所有 torchvision 操作並將其設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供了 tracer_kwargs,上述預設引數將附加到使用者提供的字典中。

  • suppress_diff_warning (bool, 可選) – 當訓練和評估圖之間存在差異時,是否抑制警告。預設為 False。

  • concrete_args (Optional[Dict[str, any]]) – 不應被視為代理(Proxy)的具體引數。根據 Pytorch 文件,此引數的 API 可能不保證。

示例

>>> # Feature extraction with resnet
>>> model = torchvision.models.resnet18()
>>> # extract layer1 and layer3, giving as names `feat1` and feat2`
>>> model = create_feature_extractor(
>>>     model, {'layer1': 'feat1', 'layer3': 'feat2'})
>>> out = model(torch.rand(1, 3, 224, 224))
>>> print([(k, v.shape) for k, v in out.items()])
>>>     [('feat1', torch.Size([1, 64, 56, 56])),
>>>      ('feat2', torch.Size([1, 256, 14, 14]))]

>>> # Specifying leaf modules and leaf functions
>>> def leaf_function(x):
>>>     # This would raise a TypeError if traced through
>>>     return int(x)
>>>
>>> class LeafModule(torch.nn.Module):
>>>     def forward(self, x):
>>>         # This would raise a TypeError if traced through
>>>         int(x.shape[0])
>>>         return torch.nn.functional.relu(x + 4)
>>>
>>> class MyModule(torch.nn.Module):
>>>     def __init__(self):
>>>         super().__init__()
>>>         self.conv = torch.nn.Conv2d(3, 1, 3)
>>>         self.leaf_module = LeafModule()
>>>
>>>     def forward(self, x):
>>>         leaf_function(x.shape[0])
>>>         x = self.conv(x)
>>>         return self.leaf_module(x)
>>>
>>> model = create_feature_extractor(
>>>     MyModule(), return_nodes=['leaf_module'],
>>>     tracer_kwargs={'leaf_modules': [LeafModule],
>>>                    'autowrap_functions': [leaf_function]})

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源