快捷方式

Transforms v2:端到端目標檢測/分割示例

注意

Colab 上嘗試,或 轉到末尾 下載完整的示例程式碼。

目標檢測和分割任務得到原生支援:torchvision.transforms.v2 可以同時轉換影像、影片、邊界框和掩碼。

本示例展示了一個端到端的例項分割訓練案例,使用了來自 torchvision.datasetstorchvision.modelstorchvision.transforms.v2 的 Torchvision 工具。這裡介紹的所有內容都可以類似地應用於目標檢測或語義分割任務。

import pathlib

import torch
import torch.utils.data

from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)

# This loads fake data for illustration purposes of this example. In practice, you'll have
# to replace this with the proper data.
# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
ROOT = pathlib.Path("../assets") / "coco"
IMAGES_PATH = str(ROOT / "images")
ANNOTATIONS_PATH = str(ROOT / "instances.json")
from helpers import plot

資料集準備

我們首先載入 CocoDetection 資料集,看看它目前返回什麼。

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH)

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{type(target[0]) = }\n{target[0].keys() = }")
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'list'>
type(target[0]) = <class 'dict'>
target[0].keys() = dict_keys(['segmentation', 'iscrowd', 'image_id', 'bbox', 'category_id', 'id'])

Torchvision 資料集會保留資料集作者最初設定的資料結構和型別。因此,預設情況下,輸出結構可能並不總是與模型或變換相容。

為了克服這個問題,我們可以使用 wrap_dataset_for_transforms_v2() 函式。對於 CocoDetection,這會將目標結構更改為單個字典列表。

dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels", "masks"))

sample = dataset[0]
img, target = sample
print(f"{type(img) = }\n{type(target) = }\n{target.keys() = }")
print(f"{type(target['boxes']) = }\n{type(target['labels']) = }\n{type(target['masks']) = }")
type(img) = <class 'PIL.Image.Image'>
type(target) = <class 'dict'>
target.keys() = dict_keys(['boxes', 'masks', 'labels'])
type(target['boxes']) = <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>
type(target['labels']) = <class 'torch.Tensor'>
type(target['masks']) = <class 'torchvision.tv_tensors._mask.Mask'>

我們使用了 target_keys 引數來指定我們感興趣的輸出型別。現在,我們的資料集返回一個字典作為目標,其中值是 TVTensor(它們都是 torch.Tensor 的子類)。我們丟棄了之前輸出中所有不必要的鍵,但如果你需要任何原始鍵(例如“image_id”),仍然可以要求它們。

注意

如果你只想進行檢測,則不需要也不應該在 target_keys 中傳遞“masks”:如果樣本中存在掩碼,它們也會被轉換,不必要地減慢你的轉換速度。

作為基線,讓我們先看一下沒有轉換的樣本。

plot([dataset[0], dataset[1]])
plot transforms e2e

變換 (Transforms)

現在,讓我們定義我們的預處理變換。所有變換都知道如何在相關時處理影像、邊界框和掩碼。

變換通常作為資料集的 transforms 引數傳遞,以便它們可以利用 torch.utils.data.DataLoader 的多程序功能。

transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.RandomPhotometricDistort(p=1),
        v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
        v2.RandomIoUCrop(),
        v2.RandomHorizontalFlip(p=1),
        v2.SanitizeBoundingBoxes(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

dataset = datasets.CocoDetection(IMAGES_PATH, ANNOTATIONS_PATH, transforms=transforms)
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys=["boxes", "labels", "masks"])
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!

這裡有幾點值得注意:

  • 我們將 PIL 影像轉換為 Image 物件。這並非嚴格必需,但依賴張量(此處為張量子類)通常會更快。

  • 我們呼叫 SanitizeBoundingBoxes 來確保我們移除退化的邊界框以及它們對應的標籤和掩碼。 SanitizeBoundingBoxes 應該至少放置一次在檢測管道的末尾;如果使用了 RandomIoUCrop,這一點尤其重要。

讓我們看看應用我們的增強管道後的樣本。

plot([dataset[0], dataset[1]])
plot transforms e2e

我們可以看到影像的顏色發生了扭曲、縮放和翻轉。邊界框和掩碼也相應地得到了轉換。無需更多廢話,我們就可以開始訓練了。

資料載入和訓練迴圈

下面我們使用 Mask-RCNN,這是一個例項分割模型,但本教程中介紹的所有內容也適用於目標檢測和語義分割任務。

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    # We need a custom collation function here, since the object detection
    # models expect a sequence of images and target dictionaries. The default
    # collation function tries to torch.stack() the individual elements,
    # which fails in general for object detection, because the number of bounding
    # boxes varies between the images of the same batch.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

model = models.get_model("maskrcnn_resnet50_fpn_v2", weights=None, weights_backbone=None).train()

for imgs, targets in data_loader:
    loss_dict = model(imgs, targets)
    # Put your training logic here

    print(f"{[img.shape for img in imgs] = }")
    print(f"{[type(target) for target in targets] = }")
    for name, loss_val in loss_dict.items():
        print(f"{name:<20}{loss_val:.3f}")
[img.shape for img in imgs] = [torch.Size([3, 512, 512]), torch.Size([3, 409, 493])]
[type(target) for target in targets] = [<class 'dict'>, <class 'dict'>]
loss_classifier     4.722
loss_box_reg        0.006
loss_mask           0.734
loss_objectness     0.691
loss_rpn_box_reg    0.036

訓練參考

在那裡,你可以檢視 torchvision 參考資料,其中包含我們用於訓練模型的實際訓練指令碼。

免責宣告:我們參考資料中的程式碼比你自己的用例所需更復雜:這是因為我們支援不同的後端(PIL、張量、TVTensors)和不同的變換名稱空間(v1 和 v2)。因此,不要害怕簡化,只保留你需要的部分。

指令碼總執行時間: (0 分鐘 4.966 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源