DDP 通訊鉤子#
創建於:2025 年 6 月 6 日 | 最後更新於:2025 年 6 月 6 日
DDP 通訊鉤子是一個通用的介面,用於透過覆蓋 DistributedDataParallel 中的標準 allreduce 來控制跨工作節點(workers)通訊梯度的方式。提供了幾個內建的通訊鉤子,使用者可以輕鬆應用這些鉤子來最佳化通訊。此外,鉤子介面還可以支援使用者自定義的通訊策略,以滿足更高階的用例。
如何使用通訊鉤子?#
要使用通訊鉤子,使用者只需在訓練迴圈開始之前,讓 DDP 模型註冊該鉤子,如下所示。
torch.nn.parallel.DistributedDataParallel.register_comm_hook()
通訊鉤子操作什麼?#
通訊鉤子提供了一種靈活的方式來 allreduce 梯度。因此,它主要在 allreduce 之前對每個副本上的梯度進行操作,這些梯度會被分桶(bucketized)以增加通訊和計算之間的重疊。特別地,torch.distributed.GradBucket 代表了一個待 allreduce 的梯度張量集合。
- class torch.distributed.GradBucket#
此類主要將一個扁平化的梯度張量(由
buffer()返回)傳遞給 DDP 通訊鉤子。該張量可以進一步分解為該分桶內按引數劃分的張量列表(由get_per_parameter_tensors()返回),以便應用層級操作。
- torch.distributed.GradBucket.index(self: torch._C._distributed_c10d.GradBucket) int#
警告
由於分桶在第一次迭代後會重建,因此在訓練開始時,不應依賴於索引。
- 返回
儲存了幾個連續層梯度的分桶的索引。所有梯度都被分桶。
- torch.distributed.GradBucket.buffer(self: torch._C._distributed_c10d.GradBucket) torch.Tensor#
- 返回
一個扁平化的 1D
torch.Tensor緩衝區,可以進一步分解為該分桶內按引數劃分的張量列表。
- torch.distributed.GradBucket.gradients(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]#
- 返回
一個
torch.Tensor列表。列表中的每個張量對應一個梯度。
- torch.distributed.GradBucket.is_last(self: torch._C._distributed_c10d.GradBucket) bool#
- 返回
此分桶是否是迭代中最後一個要 allreduce 的分桶。這也意味著此分桶對應於前向傳播中的前幾個層。
- torch.distributed.GradBucket.set_buffer(self: torch._C._distributed_c10d.GradBucket, buffer: torch.Tensor) None#
用輸入的張量緩衝區替換分桶中的張量。
- torch.distributed.GradBucket.parameters(self: torch._C._distributed_c10d.GradBucket) list[torch.Tensor]#
- 返回
一個
torch.Tensor列表。列表中的每個張量對應一個模型引數。
預設通訊鉤子#
預設通訊鉤子是簡單的**無狀態**鉤子,因此 register_comm_hook 中的 state 引數要麼是程序組(process group),要麼是 None。輸入的 bucket 是一個 torch.distributed.GradBucket 物件。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.allreduce_hook(process_group, bucket)[source]#
使用
GradBucket張量呼叫allreduce。一旦梯度張量在所有工作節點上聚合完畢,其
then回撥函式會計算平均值並返回結果。如果使用者註冊了這個 DDP 通訊鉤子,DDP 的結果預期與未註冊鉤子時相同。因此,這不會改變 DDP 的行為,使用者可以將其用作參考,或者修改此鉤子以記錄有用的資訊或用於其他目的,同時不影響 DDP 的行為。
- 示例:
>>> ddp_model.register_comm_hook(process_group, allreduce_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook(process_group, bucket)[source]#
透過將
GradBucket轉換為torch.float16併除以程序組大小來進行壓縮。這個 DDP 通訊鉤子實現了一種簡單的梯度壓縮方法,它將
GradBucket張量轉換為半精度浮點格式 (torch.float16),然後除以程序組大小。它對這些float16梯度張量進行 allreduce。一旦壓縮的梯度張量 allreduce 完成,鏈式回撥函式decompress會將其轉換回輸入資料型別(例如float32)。- 示例:
>>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_hook(process_group, bucket)[source]#
警告:此 API 尚處於實驗階段,需要 NCCL 版本大於 2.9.6。
這個 DDP 通訊鉤子實現了一種簡單的梯度壓縮方法,它將
GradBucket張量轉換為半精度 Brain 浮點格式 (torch.bfloat16),然後除以程序組大小。它對這些bfloat16梯度張量進行 allreduce。一旦壓縮的梯度張量 allreduce 完成,鏈式回撥函式decompress會將其轉換回輸入資料型別(例如float32)。- 示例:
>>> ddp_model.register_comm_hook(process_group, bf16_compress_hook)
此外,還提供了一個通訊鉤子包裝器,用於支援 fp16_compress_hook() 或 bf16_compress_hook() 作為包裝器,可以與其他通訊鉤子結合使用。
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_wrapper(hook)[source]#
將輸入張量轉換為
torch.float16,將鉤子的結果轉換回輸入資料型別。此包裝器將給定 DDP 通訊鉤子的輸入梯度張量轉換為半精度浮點格式 (
torch.float16),並將給定鉤子的結果張量轉換回輸入資料型別,例如float32。因此,fp16_compress_hook等同於fp16_compress_wrapper(allreduce_hook)。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, fp16_compress_wrapper(powerSGD_hook))
- 返回型別
Callable[[Any, GradBucket], Future[Tensor]]
- torch.distributed.algorithms.ddp_comm_hooks.default_hooks.bf16_compress_wrapper(hook)[source]#
警告:此 API 尚處於實驗階段,需要 NCCL 版本大於 2.9.6。
此包裝器將給定 DDP 通訊鉤子的輸入梯度張量轉換為半精度 Brain 浮點格式 (
torch.bfloat16),並將給定鉤子的結果張量轉換回輸入資料型別,例如float32。因此,
bf16_compress_hook等同於bf16_compress_wrapper(allreduce_hook)。- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, bf16_compress_wrapper(powerSGD_hook))
- 返回型別
Callable[[Any, GradBucket], Future[Tensor]]
PowerSGD 通訊鉤子#
PowerSGD(Vogels 等人,NeurIPS 2019)是一種梯度壓縮演算法,可以提供非常高的壓縮率並加速頻寬受限的分散式訓練。該演算法需要同時維護一些超引數和內部狀態。因此,PowerSGD 通訊鉤子是一個**有狀態**的鉤子,使用者需要提供一個如下定義的 state 物件。
PowerSGD State#
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState(process_group, matrix_approximation_rank=1, start_powerSGD_iter=1000, min_compression_rate=2, use_error_feedback=True, warm_start=True, orthogonalization_epsilon=0, random_seed=0, compression_stats_logging_frequency=10000, batch_tensors_with_same_shape=False)[source]#
儲存演算法的超引數和所有梯度在訓練期間的內部狀態。
特別地,
matrix_approximation_rank和start_powerSGD_iter是使用者應該調整的主要超引數。為了提高效能,建議保持二元超引數use_error_feedback和warm_start為 True。matrix_approximation_rank控制壓縮低秩張量的大小,這決定了壓縮率。秩越低,壓縮越強。1.1. 如果
matrix_approximation_rank太低,完整的模型質量需要更多訓練步驟才能達到,或者永遠無法達到並導致準確率損失。1.2. 增加
matrix_approximation_rank會顯著增加壓縮的計算成本,並且準確率可能在超過某個matrix_approximation_rank閾值後不再進一步提高。
為了調整
matrix_approximation_rank,我們建議從 1 開始,並以 2 的倍數增加(類似指數網格搜尋,1, 2, 4, …),直到達到滿意的準確率。通常只使用較小的值 1-4。對於某些 NLP 任務(如原始論文附錄 D 所示),此值已增加到 32。start_powerSGD_iter將 PowerSGD 壓縮推遲到第start_powerSGD_iter步,並在第start_powerSGD_iter步之前執行標準 allreduce。這種**標準 allreduce + PowerSGD** 的混合方案可以有效提高準確率,即使使用相對較小的matrix_approximation_rank。這是因為訓練的初始階段通常對不精確的梯度非常敏感,過早壓縮梯度可能會使訓練很快進入次優軌跡,從而對準確率產生不可逆轉的影響。
為了調整
start_powerSGD_iter,我們建議從總訓練步數的 10% 開始,並將其增加直到達到滿意的準確率。如果訓練中有預熱階段,start_powerSGD_iter通常不應小於預熱步數。min_compression_rate是壓縮層所需的最低壓縮率。由於壓縮會帶來計算開銷,只有當頻寬節省足夠大時,張量才值得壓縮,其中(num_rows + num_cols) * matrix_approximation_rank * min_compression_rate < num_rows * num_cols。如果指定的壓縮率閾值無法滿足,該張量將被直接 allreduce 而不進行壓縮。
一旦 PowerSGD 壓縮開始,每隔
compression_stats_logging_frequency次迭代都會記錄壓縮統計資訊。orthogonalization_epsilon可以是一個非常小的值(例如 1e-8),新增到正交化步驟中的每個歸一化矩陣列中,以防止在任何列全為 0 時發生除零錯誤。如果這個問題已經被防止(例如透過批次歸一化),則建議將 epsilon 設定為 0 以保證準確性。batch_tensors_with_same_shape控制是否將相同形狀的張量進行批處理操作以壓縮和解壓縮,以實現更高的並行度。請注意,您還應該增加分桶大小(即 DDP 建構函式中的bucket_cap_mb引數),以便在同一個分桶中出現更多相同形狀的張量,但這可能會降低計算和通訊之間的重疊,並因堆疊相同形狀的張量而增加記憶體佔用。如果壓縮/解壓縮計算是瓶頸,則將其設定為True。
警告
如果啟用了錯誤反饋或預熱,DDP 中允許的
start_powerSGD_iter的最小值是 2。這是因為 DDP 中還有一個內部最佳化會在迭代 1 時重建分桶,這可能會與在重建過程之前記住的任何張量發生衝突。
PowerSGD 鉤子#
警告
PowerSGD 通常需要與模型梯度相同大小的額外記憶體來支援錯誤反饋,這可以補償有偏的壓縮通訊並提高準確性。
警告
PowerSGD 鉤子可能與 Apex 自動混合精度包衝突。請改用 PyTorch 的原生自動混合精度包。
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.powerSGD_hook(state, bucket)[source]#
實現 PowerSGD 演算法。
這個 DDP 通訊鉤子實現了 論文 中描述的 PowerSGD 梯度壓縮演算法。一旦梯度張量在所有工作節點上聚合完畢,該鉤子按如下方式應用壓縮:
將輸入的扁平化一維梯度張量視為按引數劃分的張量列表,並將所有張量分為兩組:
1.1. 在 allreduce 之前應被壓縮的張量,因為壓縮可以在頻寬節省方面提供足夠的收益。
1.2. 其餘的張量將被直接 allreduce 而不壓縮,包括所有向量張量(用於偏置)。
處理未壓縮的張量
2.1. 為這些未壓縮的張量分配連續記憶體,並將所有未壓縮的張量作為一批進行 allreduce,不進行壓縮;
2.2. 將單個未壓縮的張量從連續記憶體複製回輸入張量。
處理應被 PowerSGD 壓縮的張量
3.1. 對於每個張量 M,建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分佈初始化並正交化;
3.2. 計算 Ps 中的每個 P,它等於 MQ;
3.3. 將 Ps 作為一批進行 allreduce;
3.4. 正交化 Ps 中的每個 P;
3.5. 計算 Qs 中的每個 Q,它約等於 M^TP;
3.6. 將 Qs 作為一批進行 allreduce;
3.7. 計算壓縮張量中的每個 M,它約等於 PQ^T。
請注意,此通訊鉤子在前
state.start_powerSGD_iter次迭代中強制執行標準 allreduce。這不僅使使用者能夠更好地控制速度提升和準確性之間的權衡,還有助於為未來的通訊鉤子開發者抽象化 DDP 的內部最佳化。- 引數
state (PowerSGDState) – 用於配置壓縮率並支援錯誤反饋、預熱等的 state 資訊。要調整壓縮配置,主要需要調整
matrix_approximation_rank、start_powerSGD_iter和min_compression_rate。bucket (dist.GradBucket) – 儲存批處理多個按變數張量的 1D 扁平化梯度張量的分桶。請注意,由於 DDP comm hook 只支援單程序單裝置模式,因此此分桶中只儲存一個張量。
- 返回
通訊的 Future 處理程式,它會就地更新梯度。
- 返回型別
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5) >>> ddp_model.register_comm_hook(state, powerSGD_hook)
- torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.batched_powerSGD_hook(state, bucket)[source]#
實現簡化的 PowerSGD 演算法。
這個 DDP 通訊鉤子實現了 論文 中描述的簡化的 PowerSGD 梯度壓縮演算法。此變體不是逐層壓縮梯度,而是壓縮批處理所有梯度的扁平化輸入張量。因此,它比
powerSGD_hook()**更快**,但通常結果是**準確率低得多**,除非matrix_approximation_rank為 1。警告
在這裡增加
matrix_approximation_rank可能不一定會提高準確率,因為在沒有行/列對齊的情況下批處理按引數張量可能會破壞低秩結構。因此,使用者應始終首先考慮powerSGD_hook(),只有當matrix_approximation_rank為 1 時能達到令人滿意的準確率時,才考慮此變體。一旦梯度張量在所有工作節點上聚合完畢,該鉤子按如下方式應用壓縮:
將輸入的扁平化一維梯度張量視為一個帶有 0 填充的方形張量 M;
建立兩個低秩張量 P 和 Q 來分解 M,使得 M = PQ^T,其中 Q 從標準正態分佈初始化並正交化;
計算 P,它等於 MQ;
allreduce P;
正交化 P;
計算 Q,它約等於 M^TP;
allreduce Q;
計算 M,它約等於 PQ^T。
將輸入張量截斷到原始長度。
請注意,此通訊鉤子在前
state.start_powerSGD_iter次迭代中強制執行標準 allreduce。這不僅使使用者能夠更好地控制速度提升和準確性之間的權衡,還有助於為未來的通訊鉤子開發者抽象化 DDP 的內部最佳化。- 引數
state (PowerSGDState) – 用於配置壓縮率並支援錯誤反饋、預熱等的 state 資訊。要調整壓縮配置,主要需要調整
matrix_approximation_rank和start_powerSGD_iter。bucket (dist.GradBucket) – 儲存批處理多個按變數張量的 1D 扁平化梯度張量的分桶。請注意,由於 DDP comm hook 只支援單程序單裝置模式,因此此分桶中只儲存一個張量。
- 返回
通訊的 Future 處理程式,它會就地更新梯度。
- 返回型別
- 示例:
>>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook)
除錯通訊鉤子#
顧名思義,除錯通訊鉤子**僅**用於除錯和效能最佳化目的。
警告
除錯通訊鉤子不一定輸出正確的結果。
- torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks.noop_hook(_, bucket)[source]#
返回一個包裝輸入張量的 Future,因此它是一個無操作(no-op),不會產生任何通訊開銷。
此鉤子**僅**用於 allreduce 最佳化的餘量分析,而不是正常的梯度同步。例如,如果註冊此鉤子後訓練時間僅觀察到不到 10% 的加速,這通常意味著 allreduce 在此情況下不是效能瓶頸。這種儀器化在 GPU 軌跡難以檢索或軌跡分析因 allreduce 與計算的重疊或跨程序的失步等因素而變得複雜時尤其有用。
- 示例:
>>> ddp_model.register_comm_hook(None, noop_hook)
通訊鉤子的檢查點#
有狀態的通訊鉤子可以作為模型檢查點的一部分進行儲存,以實現訓練器的重新啟動。要使鉤子可序列化,應定義 __setstate__ 和 __getstate__。
警告
__getstate__ 應從返回的字典中排除不可序列化屬性。
警告
__setstate__ 應正確初始化從提供的 state 中排除的不可序列化屬性。
PowerSGDState 已實現 __setstate__ 和 __getstate__,可作為參考。
- class torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook.PowerSGDState[source]
下面是一個儲存和重新載入 PowerSGD state 和 hook 的簡單端到端示例。
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(24,24)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(24,12)
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def run_demo(demo_fn, world_size):
mp.spawn(
demo_fn,
args=(world_size,),
nprocs=world_size,
join=True)
def demo_serialization(rank, world_size):
setup(rank, world_size)
CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"
model = SimpleModel().to(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank])
powersgd_hook = powerSGD.powerSGD_hook
powersgd_state = powerSGD.PowerSGDState(process_group=None)
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
state = {
'state_dict': ddp_model.state_dict(),
'comm_hook': powersgd_hook,
'comm_hook_state': powersgd_state}
if rank == 0:
torch.save(state, CHECKPOINT)
dist.barrier()
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
checkpoint = torch.load(CHECKPOINT, map_location=map_location)
new_ddp_model = DistributedDataParallel(SimpleModel().to(rank), device_ids=[rank])
new_ddp_model.load_state_dict(checkpoint['state_dict'])
powersgd_hook = checkpoint['comm_hook']
powersgd_state = checkpoint['comm_hook_state']
new_ddp_model.register_comm_hook(powersgd_state, powersgd_hook)
if rank == 0:
os.remove(CHECKPOINT)
cleanup()
if __name__ == "__main__":
n_gpus = torch.cuda.device_count()
assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
world_size = n_gpus
run_demo(demo_serialization, world_size)