create_feature_extractor¶
- torchvision.models.feature_extraction.create_feature_extractor(model: Module, return_nodes: Optional[Union[list[str], dict[str, str]]] = None, train_return_nodes: Optional[Union[list[str], dict[str, str]]] = None, eval_return_nodes: Optional[Union[list[str], dict[str, str]]] = None, tracer_kwargs: Optional[dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[dict[str, Any]] = None) GraphModule[原始碼]¶
- 建立一個新的圖模組,該模組從給定的模型返回中間節點,作為一個字典,其中使用者指定的鍵是字串,值是要提取的節點。這是透過 FX 重寫模型計算圖來實現的,以便將所需的節點作為輸出返回。所有未使用的節點及其對應的引數都將被刪除。 - 期望的輸出節點必須指定為一個以 - .分隔的路徑,該路徑從頂層模組向下遍歷到葉操作或葉模組。有關此處使用的節點命名約定的更多詳細資訊,請參閱 相關小節,位於 文件 中。- 並非所有模型都可以透過 FX 進行跟蹤,儘管稍加調整可以使其相容。以下是一些(非詳盡的)提示: - 如果您不需要跟蹤某個特定的、有問題的子模組,可以透過將 - leaf_modules列表作為- tracer_kwargs的一個引數來將其變成“葉模組”(參見下面的示例)。它不會被跟蹤,而是生成的圖將引用該模組的- forward方法。
- 同樣,您可以透過將 - autowrap_functions列表作為- tracer_kwargs的一個引數來將函式變成葉函式(參見下面的示例)。
- 一些內建的 Python 函式可能會有問題。例如, - int在跟蹤時會引發錯誤。您可以將其包裝在自己的函式中,然後將其作為- autowrap_functions的引數傳遞給- tracer_kwargs。
 - 有關 FX 的更多資訊,請參閱 torch.fx 文件。 - 引數:
- model (nn.Module) – 將要從中提取特徵的模型 
- return_nodes (list 或 dict, 可選) – - List或- Dict,包含要返回其啟用值的節點名稱(或部分名稱 - 參見上方註釋)。如果為- Dict,則鍵是節點名稱,值是圖模組返回字典的使用者指定鍵。如果為- List,則將其視為直接對映節點規範字符串到輸出名稱的- Dict。如果同時指定了- train_return_nodes和- eval_return_nodes,則不應指定此引數。
- train_return_nodes (list 或 dict, 可選) – 類似於 - return_nodes。如果訓練模式和評估模式的返回節點不同,可以使用此引數。如果指定了此引數,則必須同時指定- eval_return_nodes,並且不應指定- return_nodes。
- eval_return_nodes (list 或 dict, 可選) – 類似於 - return_nodes。如果訓練模式和評估模式的返回節點不同,可以使用此引數。如果指定了此引數,則必須同時指定- train_return_nodes,並且不應指定 return_nodes。
- tracer_kwargs (dict, 可選) – 傳遞給 - NodePathTracer(它會將引數傳遞給其父類 torch.fx.Tracer)的關鍵字引數字典。預設情況下,它將包裝所有 torchvision 操作並將其設為葉節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供了 tracer_kwargs,上述預設引數將附加到使用者提供的字典中。
- suppress_diff_warning (bool, 可選) – 當訓練和評估圖之間存在差異時,是否抑制警告。預設為 False。 
- concrete_args (Optional[Dict[str, any]]) – 不應被視為代理(Proxy)的具體引數。根據 Pytorch 文件,此引數的 API 可能不保證。 
 
 - 示例 - >>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] >>> # Specifying leaf modules and leaf functions >>> def leaf_function(x): >>> # This would raise a TypeError if traced through >>> return int(x) >>> >>> class LeafModule(torch.nn.Module): >>> def forward(self, x): >>> # This would raise a TypeError if traced through >>> int(x.shape[0]) >>> return torch.nn.functional.relu(x + 4) >>> >>> class MyModule(torch.nn.Module): >>> def __init__(self): >>> super().__init__() >>> self.conv = torch.nn.Conv2d(3, 1, 3) >>> self.leaf_module = LeafModule() >>> >>> def forward(self, x): >>> leaf_function(x.shape[0]) >>> x = self.conv(x) >>> return self.leaf_module(x) >>> >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]})