快捷方式

入門 transforms v2

注意

Colab 上試用,或 轉到末尾 下載完整的示例程式碼。

此示例說明了您需要了解的關於新 torchvision.transforms.v2 API 的所有內容。我們將介紹影像分類等簡單任務,以及物件檢測/分割等更高階的任務。

首先,做一些準備工作

from pathlib import Path
import torch
import matplotlib.pyplot as plt
plt.rcParams["savefig.bbox"] = 'tight'

from torchvision.transforms import v2
from torchvision.io import decode_image

torch.manual_seed(1)

# If you're trying to run that on Colab, you can download the assets and the
# helpers from https://github.com/pytorch/vision/tree/main/gallery/
from helpers import plot
img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
print(f"{type(img) = }, {img.dtype = }, {img.shape = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.uint8, img.shape = torch.Size([3, 512, 512])

基礎知識

Torchvision transforms 的行為類似於常規的 torch.nn.Module(事實上,它們中的大多數都是如此):例項化一個 transform,傳入一個輸入,得到一個轉換後的輸出。

transform = v2.RandomCrop(size=(224, 224))
out = transform(img)

plot([img, out])
plot transforms getting started

我只想做影像分類

如果您只關心影像分類,那麼事情會非常簡單。一個基本的分類流程可能如下所示:

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
out = transforms(img)

plot([img, out])
plot transforms getting started

這種轉換流程通常作為 transform 引數傳遞給 Datasets,例如 ImageNet(..., transform=transforms)

差不多就是這樣了。接下來,請閱讀我們的 主文件,瞭解更多關於推薦實踐和約定,或者探索更多 示例,例如如何使用增強轉換,如 CutMix 和 MixUp

注意

如果您已經依賴 torchvision.transforms v1 API,我們建議您 切換到新的 v2 transforms。這非常簡單:v2 transforms 完全相容 v1 API,所以您只需要更改匯入即可!

影片、邊界框、掩碼、關鍵點

來自 torchvision.transforms.v2 名稱空間中的 Torchvision transforms 支援影像分類以外的任務:它們還可以轉換旋轉或軸對齊的邊界框、分割/檢測掩碼、影片和關鍵點。

讓我們簡要看一個帶有邊界框的檢測示例。

from torchvision import tv_tensors  # we'll describe this a bit later, bare with us

boxes = tv_tensors.BoundingBoxes(
    [
        [15, 10, 370, 510],
        [275, 340, 510, 510],
        [130, 345, 210, 425]
    ],
    format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomPhotometricDistort(p=1),
    v2.RandomHorizontalFlip(p=1),
])
out_img, out_boxes = transforms(img, boxes)
print(type(boxes), type(out_boxes))

plot([(img, boxes), (out_img, out_boxes)])
plot transforms getting started
<class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'> <class 'torchvision.tv_tensors._bounding_boxes.BoundingBoxes'>

上面的示例側重於物件檢測。但如果我們有用於物件分割或語義分割的掩碼(torchvision.tv_tensors.Mask)或影片(torchvision.tv_tensors.Video),我們可以以完全相同的方式將它們傳遞給 transforms。

現在您可能還有幾個問題:什麼是 TVTensors?如何使用它們?這些 transforms 的預期輸入/輸出是什麼?我們將在接下來的部分回答這些問題。

什麼是 TVTensors?

TVTensors 是 torch.Tensor 的子類。可用的 TVTensors 有 ImageBoundingBoxesMaskVideoKeyPoints

TVTensors 的外觀和感覺與普通張量一樣——它們就是張量。在普通 torch.Tensor 上支援的所有操作,例如 .sum() 或任何 torch.* 運算子,同樣適用於 TVTensors。

img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")
isinstance(img_dp, torch.Tensor) = True
img_dp.dtype = torch.uint8, img_dp.shape = torch.Size([3, 256, 256]), img_dp.sum() = tensor(25087958)

這些 TVTensor 類是 transforms 的核心:為了轉換給定的輸入,transforms 首先檢視物件的,然後相應地分派到適當的實現。

您目前不需要了解更多關於 TVTensors 的知識,但想要了解更多的高階使用者可以參考 TVTensors FAQ

我應該傳入什麼作為輸入?

上面我們已經看到了兩個示例:一個示例中我們傳入了單個影像作為輸入,即 out = transforms(img);另一個示例中我們同時傳入了影像和邊界框,即 out_img, out_boxes = transforms(img, boxes)

事實上,transforms 支援任意輸入結構。輸入可以是單個影像、元組、任意巢狀的字典……幾乎任何東西。輸出將返回相同的結構。下面,我們使用相同的檢測 transforms,但將元組(影像,目標字典)作為輸入,並獲得了相同的輸出結構。

target = {
    "boxes": boxes,
    "labels": torch.arange(boxes.shape[0]),
    "this_is_ignored": ("arbitrary", {"structure": "!"})
}

# Re-using the transforms and definitions from above.
out_img, out_target = transforms(img, target)

plot([(img, target["boxes"]), (out_img, out_target["boxes"])])
print(f"{out_target['this_is_ignored']}")
plot transforms getting started
('arbitrary', {'structure': '!'})

我們傳入了一個元組,所以我們得到一個元組返回,第二個元素是轉換後的目標字典。Transforms 並不真正關心輸入的結構;正如上面提到的,它們只關心物件的型別並相應地轉換它們。

*外部*物件,如字串或整數,將被簡單地傳遞。這可能很有用,例如,如果您想在除錯時將路徑與每個樣本關聯起來!

注意

免責宣告:此說明稍微高階,初次閱讀時可以安全跳過。

純粹的 torch.Tensor 物件通常被視為影像(或對於特定於影片的 transforms,被視為影片)。事實上,您可能已經注意到,在上面的程式碼中,我們根本沒有使用 Image 類,但我們的影像仍然得到了正確轉換。Transforms 遵循以下邏輯來確定一個純張量是應被視為影像(或影片),還是僅被忽略:

  • 如果輸入中存在 ImageVideoPIL.Image.Image 例項,所有其他純張量都將被傳遞。

  • 如果不存在 ImageVideo 例項,只有第一個純 torch.Tensor 將被轉換為影像或影片,而所有其他張量將被傳遞。這裡的“第一個”意味著“在深度優先遍歷中的第一個”。

這正是上面檢測示例中所發生的情況:第一個純張量是影像,因此它被正確轉換,而所有其他純張量例項,如 labels,都被傳遞(儘管標籤仍然可以被某些 transforms 轉換,如 SanitizeBoundingBoxes!)。

Transforms 和 Datasets 的互操作性

粗略地說,datasets 的輸出必須對應於 transforms 的輸入。如何實現這一點取決於您使用的是 torchvision 的內建資料集,還是您自己的自定義資料集。

使用內建資料集

如果您只是做影像分類,您不需要做任何事情。只需使用資料集的 transform 引數,例如 ImageNet(..., transform=transforms),您就可以開始使用了。

Torchvision 還支援物件檢測或分割的資料集,如 torchvision.datasets.CocoDetection。這些資料集在 torchvision.transforms.v2 模組和 TVTensors 出現之前就已存在,因此它們預設不會返回 TVTensors。

一個強制這些資料集返回 TVTensors 並使其與 v2 transforms 相容的簡單方法是使用 torchvision.datasets.wrap_dataset_for_transforms_v2() 函式。

from torchvision.datasets import CocoDetection, wrap_dataset_for_transforms_v2

dataset = CocoDetection(..., transforms=my_transforms)
dataset = wrap_dataset_for_transforms_v2(dataset)
# Now the dataset returns TVTensors!

使用您自己的資料集

如果您有自定義資料集,那麼您需要將您的物件轉換為相應的 TVTensor 類。建立 TVTensor 例項非常簡單,有關更多詳細資訊,請參閱 如何構造一個 TVTensor?

您可以在以下兩個主要地方實現該轉換邏輯:

  • 在資料集的 __getitem__ 方法的末尾,在返回樣本之前(或透過子類化資料集)。

  • 作為您的 transforms 管道的第一步。

無論哪種方式,邏輯都將取決於您的具體資料集。

指令碼總執行時間:(0 分鐘 0.837 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源