torch.optim.Optimizer.load_state_dict#
- Optimizer.load_state_dict(state_dict)[原始碼]#
載入最佳化器狀態。
- 引數
state_dict (dict) – 最佳化器狀態。應該是一個從呼叫
state_dict()返回的物件。
警告
請確保在初始化
torch.optim.lr_scheduler.LRScheduler後呼叫此方法,因為在此之前呼叫會覆蓋載入的學習率。注意
引數的名稱(如果存在於
state_dict()中每個引數組的“param_names”鍵下)不會影響載入過程。要使用引數名稱進行自定義(例如,當載入的狀態字典中的引數與最佳化器中初始化的引數不同時),應實現自定義的register_load_state_dict_pre_hook來相應地調整載入的字典。如果param_names存在於載入的狀態字典param_groups中,它們將被儲存並覆蓋最佳化器狀態中當前存在的名稱。如果它們不存在於載入的狀態字典中,最佳化器的param_names將保持不變。示例
>>> model = torch.nn.Linear(10, 10) >>> optim = torch.optim.SGD(model.parameters(), lr=3e-4) >>> scheduler1 = torch.optim.lr_scheduler.LinearLR( ... optim, ... start_factor=0.1, ... end_factor=1, ... total_iters=20, ... ) >>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( ... optim, ... T_max=80, ... eta_min=3e-5, ... ) >>> lr = torch.optim.lr_scheduler.SequentialLR( ... optim, ... schedulers=[scheduler1, scheduler2], ... milestones=[20], ... ) >>> lr.load_state_dict(torch.load("./save_seq.pt")) >>> # now load the optimizer checkpoint after loading the LRScheduler >>> optim.load_state_dict(torch.load("./save_optim.pt"))