如何編寫自己的 TVTensor 類¶
本指南面向高階使用者和下游庫維護者。我們將解釋如何編寫自己的 TVTensor 類,以及如何使其與內建的 Torchvision v2 變換相容。在繼續之前,請確保您已閱讀 TVTensors FAQ。
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
我們將建立一個非常簡單的類,它僅繼承自基類 TVTensor。這足以涵蓋您需要了解的內容,以實現更復雜的用例。如果您需要建立一個承載元資料的類,請檢視 BoundingBoxes 類的 實現方式。
class MyTVTensor(tv_tensors.TVTensor):
pass
my_dp = MyTVTensor([1, 2, 3])
my_dp
MyTVTensor([1., 2., 3.])
現在我們已經定義了自定義 TVTensor 類,我們希望它與內建的 torchvision 變換和函式式 API 相容。為此,我們需要實現一個執行變換核心操作的核心,然後透過 register_kernel() 將其“掛鉤”到我們想要支援的函式式 API 上。
我們在下面說明了這個過程:我們為 MyTVTensor 類的“水平翻轉”操作建立一個核心,並將其註冊到函式式 API。
from torchvision.transforms.v2 import functional as F
@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
要了解為什麼使用 wrap(),請參閱 我曾擁有一個 TVTensor,但現在我擁有一個 Tensor。救命!。暫時忽略 *args, **kwargs,我們將在下面的 引數轉發,以及確保核心的未來相容性 中進行解釋。
注意
在我們上面的 register_kernel 呼叫中,我們使用字串 functional="hflip" 來引用我們想要掛鉤的函式式 API。我們也可以直接使用函式式 API 本身,即 @register_kernel(functional=F.hflip, ...)。
現在我們已經註冊了我們的核心,我們可以對 MyTVTensor 例項呼叫函式式 API。
my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)
Flipping!
我們也可以使用 RandomHorizontalFlip 變換,因為它在內部依賴於 hflip()。
t = v2.RandomHorizontalFlip(p=1)
_ = t(my_dp)
Flipping!
注意
我們不能為變換類註冊核心,只能為函式式 API 註冊核心。我們不能註冊變換類是因為一個變換可能在內部依賴於多個函式式 API,所以一般情況下我們不能為給定的類註冊一個單獨的核心。
引數轉發,以及確保核心的未來相容性¶
您正在掛鉤的函式式 API 是公共的,因此是向後相容的:我們保證這些函式式 API 的引數不會在沒有適當棄用週期的情況下被移除或重新命名。然而,我們不保證向前相容性,並且我們可能會在將來新增新引數。
設想一下,在未來的某個版本中,Torchvision 向其 hflip() 函式式 API 添加了一個新的 inplace 引數。如果您已經定義並註冊了自己的核心,如下所示:
def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
return tv_tensors.wrap(out, like=my_dp)
那麼呼叫 F.hflip(my_dp) 將會失敗,因為 hflip 將會嘗試將新的 inplace 引數傳遞給您的核心,但您的核心不接受它。
因此,我們建議始終在核心的簽名中定義 *args, **kwargs,如上所示。這樣,您的核心將能夠接受我們將來可能新增的任何新引數。(技術上來說,只新增 **kwargs 應該就足夠了)。
指令碼總執行時間: (0 分鐘 0.004 秒)