探索 TorchRec 分片#
建立日期:2022 年 5 月 10 日 | 最後更新日期:2022 年 5 月 13 日 | 最後驗證日期:2024 年 11 月 05 日
本教程將主要介紹透過 EmbeddingPlanner 和 DistributedModelParallel API 對嵌入表進行分片的方法,並透過顯式配置來探索不同分片方案對嵌入表的優勢。
安裝#
要求:- python >= 3.7
在使用 torchRec 時,我們強烈推薦使用 CUDA。如果使用 CUDA:- cuda >= 11.0
# install conda to make installying pytorch with cudatoolkit 11.3 easier.
!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local
# install pytorch with cudatoolkit 11.3
!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
安裝 torchRec 也會安裝 FBGEMM,這是一組 CUDA 核心和 GPU 啟用操作,用於執行
# install torchrec
!pip3 install torchrec-nightly
安裝 multiprocess,它與 ipython 配合使用,可在 colab 中進行多程序程式設計
!pip3 install multiprocess
Colab 執行時需要以下步驟才能檢測到新新增的共享庫。執行時會在 /usr/lib 中搜索共享庫,因此我們將安裝在 /usr/local/lib/ 中的庫複製過去。這是非常必要的一步,僅在 colab 執行時執行。
!sudo cp /usr/local/lib/lib* /usr/lib/
此時請重啟您的執行時,以便新安裝的包能夠被識別。重啟後立即執行以下步驟,以便 python 知道在哪裡查詢包。重啟執行時後,請務必執行此步驟。
import sys
sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']
分散式設定#
由於 Notebook 環境,我們無法在此處執行 SPMD 程式,但我們可以在 Notebook 中進行多程序處理來模擬設定。使用者在使用 Torchrec 時應負責設定自己的 SPMD 啟動器。我們設定了我們的環境,以便基於 torch distributed 的通訊後端能夠工作。
import os
import torch
import torchrec
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
構建我們的嵌入模型#
這裡我們使用 TorchRec 提供的 EmbeddingBagCollection 來構建包含嵌入表的嵌入袋模型。
在此,我們建立了一個包含四個嵌入袋的 EmbeddingBagCollection (EBC)。我們有兩種型別的表:大表和小表,它們之間的大小差異為 4096 對 1024。每張表仍然由 64 維的嵌入表示。
我們為表配置了 ParameterConstraints 資料結構,它為模型並行 API 提供了提示,以幫助決定表的分佈和放置策略。在 TorchRec 中,我們支援:* table-wise:將整個表放置在一個裝置上;* row-wise:按行維度均勻地分片表,並將一個分片放置在通訊域的每個裝置上;* column-wise:按嵌入維度均勻地分片表,並將一個分片放置在通訊域的每個裝置上;* table-row-wise:一種針對可用快速主機內裝置互連(如 NVLink)進行主機內通訊最佳化的特殊分片方式;* data_parallel:為每個裝置複製表;
請注意,我們最初將 EBC 分配到“meta”裝置上。這將指示 EBC 暫不分配記憶體。
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict
large_table_cnt = 2
small_table_cnt = 2
large_tables=[
torchrec.EmbeddingBagConfig(
name="large_table_" + str(i),
embedding_dim=64,
num_embeddings=4096,
feature_names=["large_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(large_table_cnt)
]
small_tables=[
torchrec.EmbeddingBagConfig(
name="small_table_" + str(i),
embedding_dim=64,
num_embeddings=1024,
feature_names=["small_table_feature_" + str(i)],
pooling=torchrec.PoolingType.SUM,
) for i in range(small_table_cnt)
]
def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
large_table_constraints = {
"large_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
) for i in range(large_table_cnt)
}
small_table_constraints = {
"small_table_" + str(i): ParameterConstraints(
sharding_types=[sharding_type.value],
) for i in range(small_table_cnt)
}
constraints = {**large_table_constraints, **small_table_constraints}
return constraints
ebc = torchrec.EmbeddingBagCollection(
device="cuda",
tables=large_tables + small_tables
)
多程序中的 DistributedModelParallel#
現在,我們有一個單程序執行函式,用於模擬 SPMD 執行期間一個 rank 的工作。
此程式碼將集體地與其他程序分片模型,並相應地分配記憶體。它首先設定程序組,使用 planner 進行嵌入表放置,並使用 DistributedModelParallel 生成分片模型。
def single_rank_execution(
rank: int,
world_size: int,
constraints: Dict[str, ParameterConstraints],
module: torch.nn.Module,
backend: str,
) -> None:
import os
import torch
import torch.distributed as dist
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ModuleSharder, ShardingEnv
from typing import cast
def init_distributed_single_host(
rank: int,
world_size: int,
backend: str,
# pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
) -> dist.ProcessGroup:
os.environ["RANK"] = f"{rank}"
os.environ["WORLD_SIZE"] = f"{world_size}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
return dist.group.WORLD
if backend == "nccl":
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
else:
device = torch.device("cpu")
topology = Topology(world_size=world_size, compute_device="cuda")
pg = init_distributed_single_host(rank, world_size, backend)
planner = EmbeddingShardingPlanner(
topology=topology,
constraints=constraints,
)
sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
plan: ShardingPlan = planner.collective_plan(module, sharders, pg)
sharded_model = DistributedModelParallel(
module,
env=ShardingEnv.from_process_group(pg),
plan=plan,
sharders=sharders,
device=device,
)
print(f"rank:{rank},sharding plan: {plan}")
return sharded_model
多程序執行#
現在讓我們在代表多個 GPU rank 的多程序中執行程式碼。
import multiprocess
def spmd_sharing_simulation(
sharding_type: ShardingType = ShardingType.TABLE_WISE,
world_size = 2,
):
ctx = multiprocess.get_context("spawn")
processes = []
for rank in range(world_size):
p = ctx.Process(
target=single_rank_execution,
args=(
rank,
world_size,
gen_constraints(sharding_type),
ebc,
"nccl"
),
)
p.start()
processes.append(p)
for p in processes:
p.join()
assert 0 == p.exitcode
表分片#
現在讓我們在兩個程序(用於 2 個 GPU)中執行程式碼。在列印的 plan 中,我們可以看到我們的表是如何跨 GPU 分片的。每個節點將擁有一個大表和一個小表,這表明我們的 planner 正在嘗試對嵌入表進行負載均衡。對於許多小型中型表,表分片是實現負載均衡的預設首選方案。
spmd_sharing_simulation(ShardingType.TABLE_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
探索其他分片模式#
我們已經初步瞭解了表分片的外觀以及它如何平衡表的放置。現在我們來探索更側重於負載均衡的分片模式:行分片。行分片專門針對單個裝置無法容納的大表,因為大表有更多的行數導致記憶體佔用增加。它可以解決模型中超大表的放置問題。使用者可以在列印的 plan 日誌的 shard_sizes 部分看到,表按行維度被一分為二,分佈到兩個 GPU 上。
spmd_sharing_simulation(ShardingType.ROW_WISE)
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
另一方面,列分片解決了具有大嵌入維度表的負載不平衡問題。我們將垂直分割表。使用者可以在列印的 plan 日誌的 shard_sizes 部分看到,表按嵌入維度被一分為二,分佈到兩個 GPU 上。
spmd_sharing_simulation(ShardingType.COLUMN_WISE)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
對於 table-row-wise,由於其在多主機設定下的特性,我們不幸無法模擬它。將來我們將提供一個 Python SPMD 示例來訓練使用 table-row-wise 的模型。
透過資料並行,我們將為所有裝置重複表。
spmd_sharing_simulation(ShardingType.DATA_PARALLEL)
rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}