torch.optim.Optimizer.state_dict#
- Optimizer.state_dict()[原始碼]#
將最佳化器的狀態作為
dict返回。它包含兩個條目
state:一個包含當前最佳化狀態的 Dict。其內容在不同的最佳化器類中會有所不同,但有一些共同的特點。例如,狀態是按引數儲存的,而引數本身不儲存。
state是一個對映引數 ID 到一個包含每個引數對應狀態的 Dict 的字典。
param_groups:一個包含所有引數組的 List,其中每個引數組是一個 Dict。每個引數組包含最佳化器特有的元資料,例如學習率和權重衰減,以及組中引數的 ID 列表。如果引數組使用
named_parameters()初始化,則名稱內容也會儲存在狀態字典中。
注意:引數 ID 可能看起來像索引,但它們只是將狀態與 param_group 關聯的 ID。從 state_dict 載入時,最佳化器會按順序匹配 param_group 的
params(int ID)和最佳化器的param_groups(實際的nn.Parameter),以匹配狀態,而無需額外驗證。返回的狀態字典可能看起來像
{ 'state': { 0: {'momentum_buffer': tensor(...), ...}, 1: {'momentum_buffer': tensor(...), ...}, 2: {'momentum_buffer': tensor(...), ...}, 3: {'momentum_buffer': tensor(...), ...} }, 'param_groups': [ { 'lr': 0.01, 'weight_decay': 0, ... 'params': [0] 'param_names' ['param0'] (optional) }, { 'lr': 0.001, 'weight_decay': 0.5, ... 'params': [1, 2, 3] 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional) } ] }