模型檢查的特徵提取¶
torchvision.models.feature_extraction 包包含特徵提取工具,讓我們能夠深入模型以訪問輸入的中間轉換。這可能對計算機視覺中的各種應用有用。舉幾個例子:
視覺化特徵圖。
提取特徵以計算影像描述符,用於人臉識別、複製檢測或影像檢索等任務。
將選定的特徵傳遞給下游子網路,以便針對特定任務進行端到端訓練。例如,將特徵層次結構傳遞給帶有物件檢測頭的特徵金字塔網路。
Torchvision 提供 create_feature_extractor() 來實現此目的。它透過大致遵循以下步驟來實現:
符號化跟蹤模型,以圖形化方式表示其如何逐步轉換輸入。
將使用者選擇的圖節點設定為輸出。
刪除所有冗餘節點(輸出節點下游的任何內容)。
從生成的圖中生成 Python 程式碼,並將其與圖本身一起打包到一個 PyTorch 模組中。
torch.fx 文件提供了關於上述過程和符號化跟蹤內部工作原理的更通用和詳細的解釋。
關於節點名稱
為了指定哪些節點應作為提取特徵的輸出節點,應該熟悉這裡使用的節點命名約定(它與torch.fx 中使用的略有不同)。節點名稱指定為透過模組層次結構從頂層模組向下到葉操作或葉模組的. 分隔的路徑。例如,ResNet-50 中的"layer4.2.relu" 表示ResNet 模組的第 4 層第 2 個塊的 ReLU 的輸出。以下是一些需要牢記的要點:
在為
create_feature_extractor()指定節點名稱時,可以提供節點名稱的截斷版本作為快捷方式。要了解其工作原理,請嘗試建立 ResNet-50 模型並使用train_nodes, _ = get_graph_node_names(model) print(train_nodes)列印節點名稱,並觀察與layer4相關的最後一個節點是"layer4.2.relu_2"。可以將"layer4.2.relu_2"指定為返回節點,或者按約定只指定"layer4",因為它指的是layer4的最後一個節點(按執行順序)。如果某個模組或操作被重複使用了一次以上,節點名稱會新增一個額外的
_{int}字尾以消除歧義。例如,可能在同一個forward方法中使用了三次加法(+)運算。那麼將會有"path.to.module.add"、"path.to.module.add_1"、"path.to.module.add_2"。計數器在直接父節點的範圍內維護。因此,在 ResNet-50 中有一個"layer4.1.add"和一個"layer4.2.add"。因為加法運算位於不同的塊中,所以不需要字尾來消除歧義。
示例
以下是我們如何為 MaskRCNN 提取特徵的示例
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
# To assist you in designing the feature extractor you may want to print out
# the available nodes for resnet50.
m = resnet50()
train_nodes, eval_nodes = get_graph_node_names(resnet50())
# The lists returned, are the names of all the graph nodes (in order of
# execution) for the input model traced in train mode and in eval mode
# respectively. You'll find that `train_nodes` and `eval_nodes` are the same
# for this example. But if the model contains control flow that's dependent
# on the training mode, they may be different.
# To specify the nodes you want to extract, you could select the final node
# that appears in each of the main layers:
return_nodes = {
# node_name: user-specified key for output dict
'layer1.2.relu_2': 'layer1',
'layer2.3.relu_2': 'layer2',
'layer3.5.relu_2': 'layer3',
'layer4.2.relu_2': 'layer4',
}
# But `create_feature_extractor` can also accept truncated node specifications
# like "layer1", as it will just pick the last node that's a descendent of
# of the specification. (Tip: be careful with this, especially when a layer
# has multiple outputs. It's not always guaranteed that the last operation
# performed is the one that corresponds to the output you desire. You should
# consult the source code for the input model to confirm.)
return_nodes = {
'layer1': 'layer1',
'layer2': 'layer2',
'layer3': 'layer3',
'layer4': 'layer4',
}
# Now you can build the feature extractor. This returns a module whose forward
# method returns a dictionary like:
# {
# 'layer1': output of layer 1,
# 'layer2': output of layer 2,
# 'layer3': output of layer 3,
# 'layer4': output of layer 4,
# }
create_feature_extractor(m, return_nodes=return_nodes)
# Let's put all that together to wrap resnet50 with MaskRCNN
# MaskRCNN requires a backbone with an attached FPN
class Resnet50WithFPN(torch.nn.Module):
def __init__(self):
super(Resnet50WithFPN, self).__init__()
# Get a resnet50 backbone
m = resnet50()
# Extract 4 main layers (note: MaskRCNN needs this particular name
# mapping for return nodes)
self.body = create_feature_extractor(
m, return_nodes={f'layer{k}': str(v)
for v, k in enumerate([1, 2, 3, 4])})
# Dry run to get number of channels for FPN
inp = torch.randn(2, 3, 224, 224)
with torch.no_grad():
out = self.body(inp)
in_channels_list = [o.shape[1] for o in out.values()]
# Build FPN
self.out_channels = 256
self.fpn = FeaturePyramidNetwork(
in_channels_list, out_channels=self.out_channels,
extra_blocks=LastLevelMaxPool())
def forward(self, x):
x = self.body(x)
x = self.fpn(x)
return x
# Now we can build our model!
model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
API 參考¶
|
建立一個新的圖模組,它將給定模型中的中間節點作為字典返回,其中鍵是使用者指定的字串,值是請求的輸出。 |
|
開發實用程式,按執行順序返回節點名稱。 |