CommDebugMode 入門#
創建於:2024 年 8 月 19 日 | 最後更新:2024 年 10 月 8 日 | 最後驗證:2024 年 11 月 5 日
作者:Anshul Sinha
在本教程中,我們將探討如何在 PyTorch 的 DistributedTensor (DTensor) 中使用 CommDebugMode 進行除錯,透過在分散式訓練環境中跟蹤集體操作。
先決條件#
Python 3.8 - 3.11
PyTorch 2.2 或更高版本
什麼是 CommDebugMode 以及它有何用處#
隨著模型規模的不斷增大,使用者正尋求利用各種並行策略的組合來擴充套件分散式訓練。然而,現有解決方案之間缺乏互操作性帶來了嚴峻的挑戰,主要原因是缺少能夠連線這些不同並行策略的統一抽象。為了解決這個問題,PyTorch 推出了 DistributedTensor(DTensor),它抽象了分散式訓練中張量通訊的複雜性,提供了無縫的使用者體驗。然而,當處理現有的並行解決方案和使用 DTensor 等統一抽象開發並行解決方案時,由於缺乏對底層集體通訊發生的時間和原因的透明度,可能會使高階使用者在識別和解決問題時面臨挑戰。為了應對這一挑戰,Python 上下文管理器 CommDebugMode 將作為 DTensor 的主要除錯工具之一,使使用者能夠檢視在使用 DTensor 時集體操作何時以及為何發生,從而有效地解決這個問題。
使用 CommDebugMode#
以下是使用 CommDebugMode 的方法
# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)
這是 MLPModule 在噪聲級別 0 下的輸出示例
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1
要使用 CommDebugMode,您必須將執行模型的程式碼包裝在 CommDebugMode 中,並呼叫您想要用來顯示資料的 API。您還可以使用 noise_level 引數來控制顯示資訊的詳細程度。以下是每個噪聲級別顯示的內容:
在上面的示例中,您可以看到集體操作 all_reduce 在 MLPModule 的前向傳播中發生了一次。此外,您可以使用 CommDebugMode 來精確地指出 all_reduce 操作發生在 MLPModule 的第二個線性層中。
以下是您可以用於上傳自己的 JSON 檔案的互動式模組樹視覺化工具
結論#
在本教程中,我們學習瞭如何使用 CommDebugMode 來除錯使用 PyTorch 的通訊集體操作的 Distributed Tensors 和並行解決方案。您可以在嵌入式視覺化瀏覽器中使用自己的 JSON 輸出。
有關 CommDebugMode 的更詳細資訊,請參閱 comm_mode_features_example.py