快捷方式

如何編寫自己的 TVTensor 類

注意

Colab 中嘗試,或 跳轉到末尾 下載完整的示例程式碼。

本指南面向高階使用者和下游庫維護者。我們將解釋如何編寫自己的 TVTensor 類,以及如何使其與內建的 Torchvision v2 變換相容。在繼續之前,請確保您已閱讀 TVTensors FAQ

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

我們將建立一個非常簡單的類,它僅繼承自基類 TVTensor。這足以涵蓋您需要了解的內容,以實現更復雜的用例。如果您需要建立一個承載元資料的類,請檢視 BoundingBoxes 類的 實現方式

class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])

現在我們已經定義了自定義 TVTensor 類,我們希望它與內建的 torchvision 變換和函式式 API 相容。為此,我們需要實現一個執行變換核心操作的核心,然後透過 register_kernel() 將其“掛鉤”到我們想要支援的函式式 API 上。

我們在下面說明了這個過程:我們為 MyTVTensor 類的“水平翻轉”操作建立一個核心,並將其註冊到函式式 API。

from torchvision.transforms.v2 import functional as F


@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

要了解為什麼使用 wrap(),請參閱 我曾擁有一個 TVTensor,但現在我擁有一個 Tensor。救命!。暫時忽略 *args, **kwargs,我們將在下面的 引數轉發,以及確保核心的未來相容性 中進行解釋。

注意

在我們上面的 register_kernel 呼叫中,我們使用字串 functional="hflip" 來引用我們想要掛鉤的函式式 API。我們也可以直接使用函式式 API 本身,即 @register_kernel(functional=F.hflip, ...)

現在我們已經註冊了我們的核心,我們可以對 MyTVTensor 例項呼叫函式式 API。

my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!

我們也可以使用 RandomHorizontalFlip 變換,因為它在內部依賴於 hflip()

t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!

注意

我們不能為變換類註冊核心,只能為函式式 API 註冊核心。我們不能註冊變換類是因為一個變換可能在內部依賴於多個函式式 API,所以一般情況下我們不能為給定的類註冊一個單獨的核心。

引數轉發,以及確保核心的未來相容性

您正在掛鉤的函式式 API 是公共的,因此是向後相容的:我們保證這些函式式 API 的引數不會在沒有適當棄用週期的情況下被移除或重新命名。然而,我們不保證向前相容性,並且我們可能會在將來新增新引數。

設想一下,在未來的某個版本中,Torchvision 向其 hflip() 函式式 API 添加了一個新的 inplace 引數。如果您已經定義並註冊了自己的核心,如下所示:

def hflip_my_tv_tensor(my_dp):  # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return tv_tensors.wrap(out, like=my_dp)

那麼呼叫 F.hflip(my_dp) 將會失敗,因為 hflip 將會嘗試將新的 inplace 引數傳遞給您的核心,但您的核心不接受它。

因此,我們建議始終在核心的簽名中定義 *args, **kwargs,如上所示。這樣,您的核心將能夠接受我們將來可能新增的任何新引數。(技術上來說,只新增 **kwargs 應該就足夠了)。

指令碼總執行時間: (0 分鐘 0.004 秒)

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源