快捷方式

如何編寫自己的 v2 變換

注意

Colab 上嘗試,或 轉到末尾 下載完整的示例程式碼。

本指南將介紹如何編寫與 torchvision transforms V2 API 相容的變換。

from typing import Any, Dict, List

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

只需建立一個 nn.Module 並重寫 forward 方法

在大多數情況下,只要您瞭解變換將期望的輸入結構,這就可以了。例如,如果您僅進行影像分類,您的變換通常會接受單個影像作為輸入,或者接受 (img, label) 作為輸入。因此,您可以直接將 forward 方法硬編碼為僅接受該輸入,例如:

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, label):
        # Do some transformations
        return new_img, new_label

注意

這意味著,如果您有一個自定義變換已相容 V1 變換(位於 torchvision.transforms 中),它無需任何更改即可與 V2 變換一起使用!

下面我們將更全面地說明這一點,以一個典型的檢測案例為例,其中我們的樣本僅為影像、邊界框和標籤。

class MyCustomTransform(torch.nn.Module):
    def forward(self, img, bboxes, label):  # we assume inputs are always structured like this
        print(
            f"I'm transforming an image of shape {img.shape} "
            f"with bboxes = {bboxes}\n{label = }"
        )
        # Do some transformations. Here, we're just passing though the input
        return img, bboxes, label


transforms = v2.Compose([
    MyCustomTransform(),
    v2.RandomResizedCrop((224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=1),
    v2.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
])

H, W = 256, 256
img = torch.rand(3, H, W)
bboxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
    format="XYXY",
    canvas_size=(H, W)
)
label = 3

out_img, out_bboxes, out_label = transforms(img, bboxes, label)
I'm transforming an image of shape torch.Size([3, 256, 256]) with bboxes = BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
label = 3
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
Output image shape: torch.Size([3, 224, 224])
out_bboxes = BoundingBoxes([[218,   7, 224,  16],
               [148,  43, 171,  61]], format=BoundingBoxFormat.XYXY, canvas_size=(224, 224), clamping_mode=soft)
out_label = 3

注意

在您的程式碼中使用 TVTensor 類時,請確保您熟悉本節:我有一個 TVTensor,現在我有一個 Tensor。求助!

支援任意輸入結構

在上一節中,我們假設您已經瞭解輸入的結構,並且您可以在程式碼中硬編碼此期望的結構。如果您希望自定義變換儘可能靈活,這可能有點侷限。

內建 Torchvision V2 變換的一個關鍵特性是它們可以接受任意輸入結構並返回相同的結構作為輸出(經過變換的條目)。例如,變換可以接受單個影像,或者一個 (img, label) 元組,或者任意巢狀的字典作為輸入。以下是內建變換 RandomHorizontalFlip 的示例:

structured_input = {
    "img": img,
    "annotations": (bboxes, label),
    "something that will be ignored": (1, "hello"),
    "another tensor that is ignored": torch.arange(10),
}
structured_output = v2.RandomHorizontalFlip(p=1)(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[246,  10, 256,  20],
               [186,  50, 206,  70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)

基礎:重寫 transform() 方法

為了支援自定義變換中的任意輸入,您需要繼承自 Transform 並重寫 .transform() 方法(而不是 forward() 方法!)。下面是一個基本示例:

class MyCustomTransform(v2.Transform):
    def transform(self, inpt: Any, params: Dict[str, Any]):
        if type(inpt) == torch.Tensor:
            print(f"I'm transforming an image of shape {inpt.shape}")
            return inpt + 1  # dummy transformation
        elif isinstance(inpt, tv_tensors.BoundingBoxes):
            print(f"I'm transforming bounding boxes! {inpt.canvas_size = }")
            return tv_tensors.wrap(inpt + 100, like=inpt)  # dummy transformation


my_custom_transform = MyCustomTransform()
structured_output = my_custom_transform(structured_input)

assert isinstance(structured_output, dict)
assert structured_output["something that will be ignored"] == (1, "hello")
assert (structured_output["another tensor that is ignored"] == torch.arange(10)).all()
print(f"The input bboxes are:\n{structured_input['annotations'][0]}")
print(f"The transformed bboxes are:\n{structured_output['annotations'][0]}")
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
The input bboxes are:
BoundingBoxes([[ 0, 10, 10, 20],
               [50, 50, 70, 70]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)
The transformed bboxes are:
BoundingBoxes([[100, 110, 110, 120],
               [150, 150, 170, 170]], format=BoundingBoxFormat.XYXY, canvas_size=(256, 256), clamping_mode=soft)

需要注意的一個重要事項是,當我們對 structured_input 呼叫 my_custom_transform 時,輸入會被展平,然後每個單獨的部分會被傳遞到 transform()。也就是說,transform() 依次接收輸入影像、邊界框等。在 transform() 中,您可以根據每個輸入的型別來決定如何轉換它們。

如果您好奇為什麼另一個張量(torch.arange())沒有被傳遞到 transform(),請參閱此說明瞭解更多詳情。

高階:make_params() 方法

在呼叫 transform() 對每個輸入進行處理之前,內部會呼叫 make_params() 方法。這通常對於生成隨機引數值很有用。在下面的示例中,我們使用它以 0.5 的機率隨機應用變換:

class MyRandomTransform(MyCustomTransform):
    def __init__(self, p=0.5):
        self.p = p
        super().__init__()

    def make_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        apply_transform = (torch.rand(size=(1,)) < self.p).item()
        params = dict(apply_transform=apply_transform)
        return params

    def transform(self, inpt: Any, params: Dict[str, Any]):
        if not params["apply_transform"]:
            print("Not transforming anything!")
            return inpt
        else:
            return super().transform(inpt, params)


my_random_transform = MyRandomTransform()

torch.manual_seed(0)
_ = my_random_transform(structured_input)  # transforms
_ = my_random_transform(structured_input)  # doesn't transform
I'm transforming an image of shape torch.Size([3, 256, 256])
I'm transforming bounding boxes! inpt.canvas_size = (256, 256)
Not transforming anything!
Not transforming anything!

注意

對於此類隨機引數生成,在 make_params() 中進行而不是在 transform() 中進行非常重要,這樣對於給定的變換呼叫,相同的 RNG 會以相同的方式應用於所有輸入。如果我們要在 transform() 中執行 RNG,我們將面臨風險,例如變換了影像,但沒有變換邊界框。

make_params() 方法接收所有輸入列表作為引數(此列表中的每個元素稍後都會傳遞給 transform())。您可以使用 flat_inputs 來例如確定輸入資料的維度,使用 query_chw()query_size()

make_params() 應該返回一個字典(或者實際上,任何您想要的東西),然後該字典將被傳遞給 transform()

指令碼總執行時間: (0 分鐘 0.009 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源