評價此頁

基於 PyTorch 的 USB 半監督學習#

建立日期: 2023 年 12 月 07 日 | 最後更新: 2024 年 03 月 07 日 | 最後驗證: 未驗證

作者: Hao Chen

Unified Semi-supervised learning Benchmark (USB) 是一個基於 PyTorch 構建的半監督學習 (SSL) 框架。基於 PyTorch 提供的 Datasets 和 Modules,USB 成為了一個靈活、模組化且易於使用的半監督學習框架。它支援各種半監督學習演算法,包括 FixMatchFreeMatchDeFixMatchSoftMatch 等。它還支援各種不平衡半監督學習演算法。USB 中包含了計算機視覺、自然語言處理和語音處理不同資料集的基準測試結果。

本教程將引導您瞭解 USB 閃電包的基礎知識。讓我們開始使用預訓練的 Vision Transformers (ViT) 在 CIFAR-10 上訓練 FreeMatch/SoftMatch 模型!我們將展示如何輕鬆更改半監督演算法並在不平衡資料集上進行訓練。

USB framework illustration

半監督學習中 FreeMatchSoftMatch 簡介#

這裡我們簡要介紹 FreeMatchSoftMatch。首先,我們介紹一種著名的半監督學習基線 FixMatchFixMatch 是一個非常簡單的半監督學習框架,它利用強增強來為無標籤資料生成偽標籤。它採用置信度閾值策略,以固定的閾值過濾掉低置信度的偽標籤。FreeMatchSoftMatch 是改進 FixMatch 的兩種演算法。FreeMatch 提出了一種自適應閾值策略來替代 FixMatch 中的固定閾值策略。自適應閾值策略根據模型在每個類上的學習狀態逐漸提高閾值。SoftMatch 將置信度閾值策略的思想吸收為一種加權機制。它提出了一種高斯加權機制來克服偽標籤的數量-質量權衡問題。在本教程中,我們將使用 USB 來訓練 FreeMatchSoftMatch

使用 USB 在 CIFAR-10 上僅用 40 個標籤訓練 FreeMatch/SoftMatch#

USB 易於使用和擴充套件,對於小型團隊來說經濟實惠,並且對於開發和評估 SSL 演算法非常全面。USB 提供了基於一致性正則化的 14 種 SSL 演算法的實現,以及來自 CV、NLP 和 Audio 領域的 15 個評估任務。它採用模組化設計,允許使用者透過新增新演算法和任務來輕鬆擴充套件該包。它還支援 Python API,以便更容易地將不同的 SSL 演算法適配到新資料上。

現在,讓我們使用 USB 在 CIFAR-10 上訓練 FreeMatchSoftMatch。首先,我們需要安裝 USB 包 semilearn 並從 USB 匯入必要的 API 函式。如果您在 Google Colab 中執行此程式,請執行以下命令來安裝 semilearn!pip install semilearn

以下是我們將在 semilearn 中使用的函式列表

  • get_dataset 用於載入資料集,這裡我們使用 CIFAR-10

  • get_data_loader 用於建立訓練(有標籤和無標籤)和測試資料

loaders,無標籤訓練的 loader 將提供無標籤資料的強增強和弱增強 - get_net_builder 用於建立模型,這裡我們使用預訓練的 ViT - get_algorithm 用於建立半監督學習演算法,這裡我們使用 FreeMatchSoftMatch - get_config:用於獲取演算法的預設配置 - Trainer:一個用於在資料集上訓練和評估演算法的 Trainer 類

請注意,使用 semilearn 包進行訓練需要 CUDA 啟用的後端。有關在 Google Colab 中啟用 CUDA 的說明,請參閱 在 Google Colab 中啟用 CUDA

import semilearn
from semilearn import get_dataset, get_data_loader, get_net_builder, get_algorithm, get_config, Trainer

匯入必要的函式後,我們首先設定演算法的超引數。

config = {
    'algorithm': 'freematch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,


    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 40,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    "num_workers": 4,
}
config = get_config(config)

然後,我們載入資料集併為訓練和測試建立資料載入器。並指定要使用的模型和演算法。

dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

現在我們可以開始在帶有 40 個標籤的 CIFAR-10 上訓練演算法了。我們訓練 500 個迭代,並每 500 個迭代進行一次評估。

trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)

最後,讓我們在驗證集上評估訓練好的模型。在僅使用 40 個 CIFAR-10 標籤上用 FreeMatch 訓練 500 個迭代後,我們得到了一個分類器,在驗證集上達到了約 87% 的準確率。

trainer.evaluate(eval_loader)

使用 USB 在不平衡的 CIFAR-10 上使用特定的不平衡演算法訓練 SoftMatch#

現在,假設我們有不平衡的 CIFAR-10 有標籤集和無標籤集,並且我們想在上面訓練一個 SoftMatch 模型。我們透過將 lb_imb_ratioulb_imb_ratio 設定為 10 來建立一個不平衡的有標籤集和不平衡的無標籤集。此外,我們將 algorithm 替換為 softmatch,並將 imbalanced 設定為 True

config = {
    'algorithm': 'softmatch',
    'net': 'vit_tiny_patch2_32',
    'use_pretrain': True,
    'pretrain_path': 'https://github.com/microsoft/Semi-supervised-learning/releases/download/v.0.0.0/vit_tiny_patch2_32_mlp_im_1k_32.pth',

    # optimization configs
    'epoch': 1,
    'num_train_iter': 500,
    'num_eval_iter': 500,
    'num_log_iter': 50,
    'optim': 'AdamW',
    'lr': 5e-4,
    'layer_decay': 0.5,
    'batch_size': 16,
    'eval_batch_size': 16,


    # dataset configs
    'dataset': 'cifar10',
    'num_labels': 1500,
    'num_classes': 10,
    'img_size': 32,
    'crop_ratio': 0.875,
    'data_dir': './data',
    'ulb_samples_per_class': None,
    'lb_imb_ratio': 10,
    'ulb_imb_ratio': 10,
    'ulb_num_labels': 3000,

    # algorithm specific configs
    'hard_label': True,
    'T': 0.5,
    'ema_p': 0.999,
    'ent_loss_ratio': 0.001,
    'uratio': 2,
    'ulb_loss_ratio': 1.0,

    # device configs
    'gpu': 0,
    'world_size': 1,
    'distributed': False,
    "num_workers": 4,
}
config = get_config(config)

然後,我們重新載入資料集併為訓練和測試建立資料載入器。並指定要使用的模型和演算法。

dataset_dict = get_dataset(config, config.algorithm, config.dataset, config.num_labels, config.num_classes, data_dir=config.data_dir, include_lb_to_ulb=config.include_lb_to_ulb)
train_lb_loader = get_data_loader(config, dataset_dict['train_lb'], config.batch_size)
train_ulb_loader = get_data_loader(config, dataset_dict['train_ulb'], int(config.batch_size * config.uratio))
eval_loader = get_data_loader(config, dataset_dict['eval'], config.eval_batch_size)
algorithm = get_algorithm(config,  get_net_builder(config.net, from_name=False), tb_log=None, logger=None)

現在我們可以開始在帶有 40 個標籤的 CIFAR-10 上訓練演算法了。我們訓練 500 個迭代,並每 500 個迭代進行一次評估。

trainer = Trainer(config, algorithm)
trainer.fit(train_lb_loader, train_ulb_loader, eval_loader)

最後,讓我們在驗證集上評估訓練好的模型。

trainer.evaluate(eval_loader)

參考文獻: - [1] USB: microsoft/Semi-supervised-learning - [2] Kihyuk Sohn 等人. FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence - [3] Yidong Wang 等人. FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning - [4] Hao Chen 等人. SoftMatch: Addressing the Quantity-Quality Trade-off in Semi-supervised Learning