get_graph_node_names¶
- torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[dict[str, Any]] = None) tuple[list[str], list[str]][原始碼]¶
用於返回按執行順序排列的節點的開發實用程式。有關節點名稱的說明,請參閱
create_feature_extractor()。這對於檢視可用於特徵提取的節點名稱非常有用。節點名稱無法輕鬆從模型程式碼中直接讀取有兩個原因:並非所有子模組都會被追蹤。來自
torch.nn的模組都屬於此類。表示相同操作或葉子模組重複應用的節點會獲得一個
_{counter}字尾。
模型會追蹤兩次:一次在訓練模式下,一次在評估模式下。將返回兩個節點的名稱集。
有關此處使用的節點命名約定的更多詳細資訊,請參閱 相關小節,位於 文件 中。
- 引數:
model (nn.Module) – 我們想要列印節點名稱的模型
tracer_kwargs (dict, optional) – 用於
NodePathTracer的關鍵字引數字典(最終會傳遞給 torch.fx.Tracer)。預設情況下,它將包裝並使所有 torchvision 操作成為葉子節點:{“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果使用者提供 tracer_kwargs,上述預設引數將附加到使用者提供的字典中。suppress_diff_warning (bool, optional) – 當訓練和評估版本圖存在差異時,是否抑制警告。預設為 False。
concrete_args (Optional[Dict[str, any]]) – 不應被視為代理的具體引數。根據 Pytorch 文件,此引數的 API 可能不被保證。
- 返回:
一個在訓練模式下追蹤模型得到的節點名稱列表,以及一個在評估模式下追蹤模型得到的節點名稱列表。
- 返回型別:
示例
>>> model = torchvision.models.resnet18() >>> train_nodes, eval_nodes = get_graph_node_names(model)