如何使用 CutMix 和 MixUp¶
CutMix 和 MixUp 是流行的增強策略,可以提高分類準確性。
這些轉換與 Torchvision 的其他轉換略有不同,因為它們需要 **批次** 樣本作為輸入,而不是單個影像。在此示例中,我們將解釋如何使用它們:在 DataLoader 之後,或者作為 collate 函式的一部分。
import torch
from torchvision.datasets import FakeData
from torchvision.transforms import v2
NUM_CLASSES = 100
預處理管道¶
我們將使用一個簡單但典型的影像分類管道
preproc = v2.Compose([
v2.PILToTensor(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomHorizontalFlip(p=0.5),
v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1]
v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet
])
dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)
img, label = dataset[0]
print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
type(img) = <class 'torch.Tensor'>, img.dtype = torch.float32, img.shape = torch.Size([3, 224, 224]), label = 67
一個需要注意的重要事項是,CutMix 和 MixUp 都不屬於此預處理管道。我們將在定義 DataLoader 後稍後新增它們。僅作為回顧,如果我們不使用 CutMix 或 MixUp,DataLoader 和訓練迴圈將如下所示。
from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
print(labels.dtype)
# <rest of the training loop here>
break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
torch.int64
MixUp 和 CutMix 的使用位置¶
在 DataLoader 之後¶
現在讓我們新增 CutMix 和 MixUp。最簡單的方法是直接在 DataLoader 之後進行:DataLoader 已經為我們批處理了影像和標籤,而這正是這些轉換所期望的輸入。
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
cutmix = v2.CutMix(num_classes=NUM_CLASSES)
mixup = v2.MixUp(num_classes=NUM_CLASSES)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
for images, labels in dataloader:
print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
images, labels = cutmix_or_mixup(images, labels)
print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")
# <rest of the training loop here>
break
Before CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4])
After CutMix/MixUp: images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])
請注意,標籤也已轉換:我們從形狀為 (batch_size,) 的批處理標籤變為形狀為 (batch_size, num_classes) 的張量。轉換後的標籤仍然可以直接傳遞給諸如 torch.nn.functional.cross_entropy() 之類的損失函式。
作為 collate 函式的一部分¶
在 DataLoader 之後傳遞轉換是使用 CutMix 和 MixUp 的最簡單方法,但一個缺點是它沒有利用 DataLoader 的多程序。為此,我們可以將這些轉換作為 collate 函式的一部分傳遞(有關 collate 的更多資訊,請參閱 PyTorch 文件)。
from torch.utils.data import default_collate
def collate_fn(batch):
return cutmix_or_mixup(*default_collate(batch))
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
for images, labels in dataloader:
print(f"{images.shape = }, {labels.shape = }")
# No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
# <rest of the training loop here>
break
images.shape = torch.Size([4, 3, 224, 224]), labels.shape = torch.Size([4, 100])
非標準輸入格式¶
到目前為止,我們使用了一個典型的樣本結構,我們將 (images, labels) 作為輸入。MixUp 和 CutMix 預設情況下可以與大多數常見樣本結構(元組,其中第二個引數是張量標籤,或帶有“label[s]”鍵的字典)一起神奇地工作。有關更多詳細資訊,請檢視 labels_getter 引數的文件。
如果您的樣本具有不同的結構,您仍然可以透過將可呼叫物件傳遞給 labels_getter 引數來使用 CutMix 和 MixUp。例如:
batch = {
"imgs": torch.rand(4, 3, 224, 224),
"target": {
"classes": torch.randint(0, NUM_CLASSES, size=(4,)),
"some_other_key": "this is going to be passed-through"
}
}
def labels_getter(batch):
return batch["target"]["classes"]
out = v2.CutMix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
out['imgs'].shape = torch.Size([4, 3, 224, 224]), out['target']['classes'].shape = torch.Size([4, 100])
指令碼總執行時間: (0 分鐘 0.188 秒)