評價此頁

torch.optim#

創建於: 2025年6月13日 | 最後更新於: 2025年8月24日

torch.optim 是一個實現了各種最佳化演算法的包。

大多數常用方法已得到支援,並且介面足夠通用,以便將來可以輕鬆整合更復雜的方法。

如何使用最佳化器#

要使用 torch.optim,您需要構造一個最佳化器物件,該物件將儲存當前狀態並根據計算出的梯度更新引數。

構造它#

要構造一個 Optimizer,您需要為其提供一個包含要最佳化的引數(所有引數都應為 Parameter)或命名引數((str, Parameter) 的元組)的可迭代物件。然後,您可以指定特定於最佳化器的選項,例如學習率、權重衰減等。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)

命名引數示例

optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001)

按引數分組的選項#

Optimizer 還支援按引數分組指定選項。為此,請不要傳遞 Variable 的可迭代物件,而是傳遞 dict 的可迭代物件。每個字典將定義一個單獨的引數組,並應包含一個 params 鍵,其中包含屬於該組的引數列表。其他鍵應與最佳化器接受的關鍵字引數匹配,並將用作此組的最佳化選項。

例如,當一個人想要為每個層指定不同的學習率時,這非常有用。

optim.SGD([
    {'params': model.base.parameters(), 'lr': 1e-2},
    {'params': model.classifier.parameters()}
], lr=1e-3, momentum=0.9)

optim.SGD([
    {'params': model.base.named_parameters(), 'lr': 1e-2},
    {'params': model.classifier.named_parameters()}
], lr=1e-3, momentum=0.9)

這意味著 model.base 的引數將使用 1e-2 的學習率,而 model.classifier 的引數將保持預設學習率 1e-3。最後,將對所有引數使用 0.9 的動量。

注意

您仍然可以傳遞選項作為關鍵字引數。未在組中覆蓋的選項將使用它們作為預設值。當您只想更改一個選項,同時保持引數組之間的所有其他選項一致時,這很有用。

另請考慮以下與引數區分懲罰相關的示例。請記住,parameters() 返回一個可迭代物件,其中包含所有可學習的引數,包括可能需要區分懲罰的偏置和其他引數。為了解決這個問題,可以為每個引數組指定單獨的懲罰權重。

bias_params = [p for name, p in self.named_parameters() if 'bias' in name]
others = [p for name, p in self.named_parameters() if 'bias' not in name]

optim.SGD([
    {'params': others},
    {'params': bias_params, 'weight_decay': 0}
], weight_decay=1e-2, lr=1e-2)

這樣,偏置項與非偏置項分開,並且為偏置項專門設定了 0weight_decay,以避免對該組進行任何懲罰。

執行最佳化步驟#

所有最佳化器都實現了一個 step() 方法,用於更新引數。它有兩種用法:

optimizer.step()#

這是大多數最佳化器支援的簡化版本。在計算出梯度後(例如使用 backward()),即可呼叫該函式。

示例

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()

optimizer.step(closure)#

某些最佳化演算法,例如共軛梯度法和 L-BFGS,需要多次重新評估函式,因此您必須傳遞一個閉包,以便它們可以重新計算模型。閉包應清除梯度,計算損失並返回它。

示例

for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)

基類#

class torch.optim.Optimizer(params, defaults)[source]#

所有最佳化器的基類。

警告

需要將引數指定為具有確定性順序的物件,並且該順序在不同執行之間是一致的。不滿足這些屬性的物件示例包括集合以及字典值的迭代器。

引數
  • params (iterable) – 一個 torch.Tensordict 的可迭代物件。指定要最佳化的 Tensor。

  • defaults (dict[str, Any]) – (dict):一個包含最佳化選項預設值的字典(當引數組未指定時使用)。

Optimizer.add_param_group

向最佳化器的 param_groups 新增一個引數組。

Optimizer.load_state_dict

載入最佳化器狀態。

Optimizer.register_load_state_dict_pre_hook

註冊一個 load_state_dict 前置鉤子,該鉤子將在呼叫 load_state_dict() 之前呼叫。它應該具有以下簽名:。

Optimizer.register_load_state_dict_post_hook

註冊一個 load_state_dict 後置鉤子,該鉤子將在呼叫 load_state_dict() 之後呼叫。它應該具有以下簽名:。

Optimizer.state_dict

將最佳化器的狀態作為 dict 返回。

Optimizer.register_state_dict_pre_hook

註冊一個 state_dict 前置鉤子,該鉤子將在呼叫 state_dict() 之前呼叫。

Optimizer.register_state_dict_post_hook

註冊一個 state_dict 後置鉤子,該鉤子將在呼叫 state_dict() 之後呼叫。

Optimizer.step

執行一次最佳化步驟來更新引數。

Optimizer.register_step_pre_hook

註冊一個最佳化器步驟預鉤子,它將在最佳化器步驟之前被呼叫。

Optimizer.register_step_post_hook

註冊一個最佳化器步驟後鉤子,它將在最佳化器步驟之後被呼叫。

Optimizer.zero_grad

重置所有已最佳化 torch.Tensor 的梯度。

演算法#

Adadelta

實現了 Adadelta 演算法。

Adafactor

實現了 Adafactor 演算法。

Adagrad

實現了 Adagrad 演算法。

Adam

實現了 Adam 演算法。

AdamW

實現了 AdamW 演算法,其中權重衰減不累積到動量或方差中。

SparseAdam

SparseAdam 實現了一個 Adam 演算法的掩碼版本,適用於稀疏梯度。

Adamax

實現了 Adamax 演算法(基於無窮範數的 Adam 變體)。

ASGD

實現了平均隨機梯度下降。

LBFGS

實現了 L-BFGS 演算法。

Muon

實現了 Muon 演算法。

NAdam

實現了 NAdam 演算法。

RAdam

實現了 RAdam 演算法。

RMSprop

實現了 RMSprop 演算法。

Rprop

實現了彈性反向傳播演算法。

SGD

實現了隨機梯度下降(可選帶動量)。

我們的許多演算法都有各種針對性能、可讀性和/或通用性進行最佳化的實現,因此,如果我們沒有指定任何特定的實現,我們會嘗試預設使用當前裝置上通常最快的實現。

我們有 3 個主要的實現類別:for-loop、foreach(多張量)和 fused。最直接的實現是在具有大型計算塊的引數上使用 for-loop。For-loop 通常比我們的 foreach 實現慢,後者將引數合併到一個多張量中,一次性執行大型計算塊,從而節省了許多順序核心呼叫。我們的一些最佳化器甚至有更快的 fused 實現,它們將大型計算塊融合到一個核心中。我們可以將 foreach 實現視為水平融合,而 fused 實現則在此基礎上進行垂直融合。

通常,3 種實現的效能順序是 fused > foreach > for-loop。因此,在適用時,我們預設使用 foreach 而不是 for-loop。適用意味著 foreach 實現可用,使用者沒有指定任何特定於實現的 kwargs(例如,fused、foreach、differentiable),並且所有張量都是原生的。請注意,雖然 fused 應該比 foreach 更快,但這些實現較新,我們希望在全面啟用之前讓它們有更多的時間來完善。我們在下表總結了每種實現的穩定性狀態,歡迎您嘗試!

下表顯示了每種演算法可用的和預設的實現。

演算法

預設

有 foreach 嗎?

有 fused 嗎?

Adadelta

foreach

Adafactor

for-loop

Adagrad

foreach

是(僅限 CPU)

Adam

foreach

AdamW

foreach

SparseAdam

for-loop

Adamax

foreach

ASGD

foreach

LBFGS

for-loop

Muon

for-loop

NAdam

foreach

RAdam

foreach

RMSprop

foreach

Rprop

foreach

SGD

foreach

下表顯示了 fused 實現的穩定性狀態。

演算法

CPU

CUDA

MPS

Adadelta

不支援

不支援

不支援

Adafactor

不支援

不支援

不支援

Adagrad

beta

不支援

不支援

Adam

beta

穩定

beta

AdamW

beta

穩定

beta

SparseAdam

不支援

不支援

不支援

Adamax

不支援

不支援

不支援

ASGD

不支援

不支援

不支援

LBFGS

不支援

不支援

不支援

Muon

不支援

不支援

不支援

NAdam

不支援

不支援

不支援

RAdam

不支援

不支援

不支援

RMSprop

不支援

不支援

不支援

Rprop

不支援

不支援

不支援

SGD

beta

beta

beta

如何調整學習率#

torch.optim.lr_scheduler.LRScheduler 提供了幾種根據 epoch 數量調整學習率的方法。torch.optim.lr_scheduler.ReduceLROnPlateau 允許根據一些驗證度量動態降低學習率。

學習率排程應在最佳化器更新後應用;例如,您的程式碼應這樣編寫:

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler.step()

大多數學習率排程器可以連續呼叫(也稱為鏈式排程器)。其結果是每個排程器按順序應用於前一個排程器獲得學習率。

示例

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    scheduler1.step()
    scheduler2.step()

在文件的許多地方,我們將使用以下模板來引用排程器演算法。

>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()

警告

在 PyTorch 1.1.0 之前,學習率排程器應在最佳化器更新之前呼叫;1.1.0 以向後不相容的方式改變了此行為。如果您在更新到 PyTorch 1.1.0 後無法重現結果,請檢查您是否在錯誤的時間呼叫了 scheduler.step()

lr_scheduler.LRScheduler

在最佳化過程中調整學習率。

lr_scheduler.LambdaLR

設定初始學習率。

lr_scheduler.MultiplicativeLR

透過指定函式中的因子來乘以每個引數組的學習率。

lr_scheduler.StepLR

每 step_size 個 epoch,將每個引數組的學習率按 gamma 衰減。

lr_scheduler.MultiStepLR

當 epoch 數量達到 milestones 之一時,將每個引數組的學習率按 gamma 衰減一次。

lr_scheduler.ConstantLR

將每個引數組的學習率乘以一個小的常數因子。

lr_scheduler.LinearLR

透過線性改變小的乘法因子來衰減每個引數組的學習率。

lr_scheduler.ExponentialLR

每 epoch,將每個引數組的學習率按 gamma 衰減。

lr_scheduler.PolynomialLR

使用給定 total_iters 中的多項式函式來衰減每個引數組的學習率。

lr_scheduler.CosineAnnealingLR

使用餘弦退火排程設定每個引數組的學習率。

lr_scheduler.ChainedScheduler

連結一系列學習率排程器。

lr_scheduler.SequentialLR

包含一系列排程器,這些排程器預計在最佳化過程中按順序呼叫。

lr_scheduler.ReduceLROnPlateau

當某個度量停止改進時,降低學習率。

lr_scheduler.CyclicLR

根據迴圈學習率策略(CLR)設定每個引數組的學習率。

lr_scheduler.OneCycleLR

根據 1cycle 學習率策略設定每個引數組的學習率。

lr_scheduler.CosineAnnealingWarmRestarts

使用餘弦退火排程設定每個引數組的學習率。

如何利用命名引數載入最佳化器 state_dict#

函式 load_state_dict() 會儲存從載入的 state_dict 中可選的 param_names 內容(如果存在)。但是,載入最佳化器狀態的過程不受影響,因為引數的順序很重要,可以保持相容性(以防順序不同)。要利用載入的 state_dict 中的已載入引數名稱,需要根據期望的行為實現自定義 register_load_state_dict_pre_hook

這在模型架構發生變化但權重和最佳化器狀態需要保持不變的情況下很有用。以下示例演示瞭如何實現此自定義。

示例

class OneLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

model = OneLayerModel()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

假設 model 實現了一個專家(MoE),我們想複製它並在兩個專家上恢復訓練,這兩個專家都以與 fc 層相同的方式初始化。對於下面的 model2,我們建立了兩個與 fc 相同的層,並透過將 model 的權重和最佳化器狀態載入到 model2fc1fc2 中來恢復訓練(並相應地調整它們)。

class TwoLayerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(3, 4)

    def forward(self, x):
        return (self.fc1(x) + self.fc2(x)) / 2

model2 = TwoLayerModel()
# adapt and load model weights..
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

為了載入 optimizer2 的 state dict,同時載入先前最佳化器的 state dict,以便 fc1fc2 都將用 fc 最佳化器狀態的副本進行初始化(以便從 fc 繼續訓練每個層),我們可以使用以下鉤子:

def adapt_state_dict_ids(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc1.weight': 'fc.weight',
        'fc1.bias': 'fc.bias',
        'fc2.weight': 'fc.weight',
        'fc2.bias': 'fc.bias'
    }
    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
        id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
        # Copy the state of the corresponding parameter
        if id_in_loaded in state_dict['state']:
            adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

這確保了在模型載入期間將使用適應的 state_dict,其中包含 model2 的層的正確狀態。請注意,此程式碼是專門為此示例設計的(例如,假設只有一個引數組),其他情況可能需要不同的調整。

以下示例顯示瞭如何在模型結構更改時處理載入的 state dict 中缺失的引數。Model_bypass 添加了一個新的 bypass 層,該層在原始 Model1 中不存在。為了恢復訓練,使用自定義的 adapt_state_dict_missing_param 鉤子來適應最佳化器的 state_dict,確保現有引數對映正確,而缺失的引數(如示例中初始化的 bypass 層)保持不變。這種方法使得即使模型發生變化,也能平滑地載入和恢復最佳化器狀態。新新增的 bypass 層將從頭開始訓練。

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)

    def forward(self, x):
        return self.fc(x) + x


model = Model1()
optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9)
# training..
torch.save(optimizer.state_dict(), PATH)

class Model_bypass(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 5)
        self.bypass = nn.Linear(5, 5, bias=False)
        torch.nn.init.eye_(self.bypass.weight)

    def forward(self, x):
        return self.fc(x) + self.bypass(x)

model2 = Model_bypass()
optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9)

def adapt_state_dict_missing_param(optimizer, state_dict):
    adapted_state_dict = deepcopy(optimizer.state_dict())
    # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict.
    for k, v in state_dict['param_groups'][0].items():
        if k not in ['params', 'param_names']:
            adapted_state_dict['param_groups'][0][k] = v

    lookup_dict = {
        'fc.weight': 'fc.weight',
        'fc.bias': 'fc.bias',
        'bypass.weight': None,
    }

    clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()}
    for param_id, param_name in zip(
            optimizer.state_dict()['param_groups'][0]['params'],
            optimizer.state_dict()['param_groups'][0]['param_names']):
        name_in_loaded = lookup_dict[param_name]
        if name_in_loaded in state_dict['param_groups'][0]['param_names']:
            index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded)
            id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids)
optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict

作為第三個示例,該鉤子可以用於根據引數的名稱載入,而不是根據引數的順序(預設方法)。

def names_matching(optimizer, state_dict):
    assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups'])
    adapted_state_dict = deepcopy(optimizer.state_dict())
    for g_ind in range(len(state_dict['param_groups'])):
        assert len(state_dict['param_groups'][g_ind]['params']) == len(
            optimizer.state_dict()['param_groups'][g_ind]['params'])

        for k, v in state_dict['param_groups'][g_ind].items():
            if k not in ['params', 'param_names']:
                adapted_state_dict['param_groups'][g_ind][k] = v

        for param_id, param_name in zip(
                optimizer.state_dict()['param_groups'][g_ind]['params'],
                optimizer.state_dict()['param_groups'][g_ind]['param_names']):
            index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name)
            id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list]
            # Copy the state of the corresponding parameter
            if id_in_loaded in state_dict['state']:
                adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded])

    return adapted_state_dict

權重平均(SWA 和 EMA)#

torch.optim.swa_utils.AveragedModel 實現隨機權重平均(SWA)和指數移動平均(EMA),torch.optim.swa_utils.SWALR 實現 SWA 學習率排程器,而 torch.optim.swa_utils.update_bn() 是一個在訓練結束時用於更新 SWA/EMA 批次歸一化統計量的實用函式。

SWA 在 Averaging Weights Leads to Wider Optima and Better Generalization 中被提出。

EMA 是一種廣泛用於減少訓練時間的技術,透過減少所需的權重更新次數。它是 Polyak averaging 的一個變種,但使用指數權重而不是迭代之間的相等權重。

構造平均模型#

AveragedModel 類用於計算 SWA 或 EMA 模型的權重。

您可以透過執行以下命令建立一個 SWA 平均模型:

>>> averaged_model = AveragedModel(model)

EMA 模型透過將 `multi_avg_fn` 引數指定為以下方式來構造:

>>> decay = 0.999
>>> averaged_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(decay))

衰減是一個介於 0 和 1 之間的引數,它控制平均引數衰減的速度。如果未提供給 torch.optim.swa_utils.get_ema_multi_avg_fn(),則預設值為 0.999。衰減值應接近 1.0,因為較小的值可能導致最佳化收斂問題。

torch.optim.swa_utils.get_ema_multi_avg_fn() 返回一個函式,該函式將以下 EMA 方程應用於權重:

Wt+1EMA=αWtEMA+(1α)WtmodelW^\textrm{EMA}_{t+1} = \alpha W^\textrm{EMA}_{t} + (1 - \alpha) W^\textrm{model}_t

其中 alpha 是 EMA 衰減。

這裡的 model model 可以是任意的 torch.nn.Module 物件。averaged_model 將跟蹤 model 引數的執行平均值。要更新這些平均值,您應該在 optimizer.step() 之後使用 update_parameters() 函式。

>>> averaged_model.update_parameters(model)

對於 SWA 和 EMA,此呼叫通常在最佳化器 step() 之後不久進行。在 SWA 的情況下,這通常在訓練開始的某些步數內跳過。

自定義平均策略#

預設情況下,torch.optim.swa_utils.AveragedModel 計算您提供的引數的執行平均值,但您也可以使用 `avg_fn` 或 `multi_avg_fn` 引數來自定義平均函式。

  • avg_fn 允許定義一個對每個引數元組(平均引數,模型引數)進行操作的函式,並應返回新的平均引數。

  • multi_avg_fn 允許同時定義更高效的操作,這些操作作用於引數列表元組(平均引數列表,模型引數列表),例如使用 torch._foreach* 函式。此函式必須就地更新平均引數。

在以下示例中,ema_model 使用 avg_fn 引數計算指數移動平均。

>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>         0.9 * averaged_model_parameter + 0.1 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)

在以下示例中,ema_model 使用更高效的 multi_avg_fn 引數計算指數移動平均。

>>> ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.9))

SWA 學習率排程#

通常,在 SWA 中,學習率被設定為一個高常量值。SWALR 是一個學習率排程器,它將學習率衰減到一個固定值,然後保持不變。例如,以下程式碼建立一個排程器,該排程器在每個引數組的 5 個 epoch 內將學習率從初始值線性衰減到 0.05。

>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, \
>>>         anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

您還可以透過設定 anneal_strategy="cos" 來使用餘弦退火到一個固定值,而不是線性退火。

處理批次歸一化#

update_bn() 是一個實用函式,它允許在訓練結束時在給定的資料載入器 loader 上計算 SWA 模型的批次歸一化統計量。

>>> torch.optim.swa_utils.update_bn(loader, swa_model)

update_bn()swa_model 應用於資料載入器中的每個元素,並計算模型中每個批次歸一化層的啟用統計量。

警告

update_bn() 假設資料載入器 loader 中的每個批次是張量,或者是一個張量列表/元組,其中第一個元素是網路 swa_model 應該應用的張量。如果您的資料載入器結構不同,您可以透過在資料集的每個元素上進行前向傳遞(使用 swa_model)來更新 swa_model 的批次歸一化統計量。

總而言之:SWA#

在下面的示例中,swa_model 是累積權重平均值的 SWA 模型。我們將模型訓練總共 300 個 epoch,並在 epoch 160 時切換到 SWA 學習率排程器並開始收集引數的 SWA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>       if epoch > swa_start:
>>>           swa_model.update_parameters(model)
>>>           swa_scheduler.step()
>>>       else:
>>>           scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)

總而言之:EMA#

在下面的示例中,ema_model 是 EMA 模型,它以 0.999 的衰減率累積權重的指數衰減平均值。我們將模型訓練總共 300 個 epoch,並立即開始收集 EMA 平均值。

>>> loader, optimizer, model, loss_fn = ...
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, \
>>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
>>>
>>> for epoch in range(300):
>>>       for input, target in loader:
>>>           optimizer.zero_grad()
>>>           loss_fn(model(input), target).backward()
>>>           optimizer.step()
>>>           ema_model.update_parameters(model)
>>>
>>> # Update bn statistics for the ema_model at the end
>>> torch.optim.swa_utils.update_bn(loader, ema_model)
>>> # Use ema_model to make predictions on test data
>>> preds = ema_model(test_input)

swa_utils.AveragedModel

為隨機權重平均(SWA)和指數移動平均(EMA)實現平均模型。

swa_utils.SWALR

將每個引數組的學習率衰減到一個固定值。

torch.optim.swa_utils.get_ema_multi_avg_fn(decay=0.999)[source]#

獲取跨多個引數應用指數移動平均(EMA)的函式。

torch.optim.swa_utils.update_bn(loader, model, device=None)[source]#

更新模型中的 BatchNorm running_mean、running_var 緩衝區。

它會遍歷 loader 中的資料一次,以估算模型中 BatchNorm 層的啟用統計量。

引數
  • loader (torch.utils.data.DataLoader) – 用於計算啟用統計量的資料集載入器。每個資料批次都應該是張量,或者是一個列表/元組,其中第一個元素是包含資料的張量。

  • model (torch.nn.Module) – 我們要為其更新 BatchNorm 統計量的模型。

  • device (torch.device, optional) – 如果設定,資料將在傳入 model 之前傳輸到 device

示例

>>> loader, model = ...
>>> torch.optim.swa_utils.update_bn(loader, model)

注意

cite>update_bn 實用函式假定 loader 中的每個資料批次都是張量,或者是張量的列表或元組;在後一種情況下,假定 model.forward() 應該在與資料批次對應的列表或元組的第一個元素上呼叫。