評價此頁

TorchRec 簡介#

建立日期: 2024年10月02日 | 最後更新: 2025年07月10日 | 最後驗證: 2024年10月02日

TorchRec 是一個針對使用嵌入式(embeddings)構建可擴充套件、高效推薦系統的 PyTorch 庫。本教程將引導您完成安裝過程,介紹嵌入式的概念,並強調其在推薦系統中的重要性。教程將提供使用 PyTorch 和 TorchRec 實現嵌入式的實踐演示,重點關注透過分散式訓練和高階最佳化來處理大型嵌入表。

您將學到什麼
  • 嵌入式的基本原理及其在推薦系統中的作用

  • 如何在 PyTorch 環境中設定 TorchRec 來管理和實現嵌入式

  • 探索將大型嵌入表分佈到多個 GPU 上的高階技術

先決條件
  • PyTorch v2.5 或更高版本,以及 CUDA 11.8 或更高版本

  • Python 3.9 或更高版本

  • FBGEMM

安裝依賴項#

在 Google Colab 中執行本教程之前,請確保安裝以下依賴項

!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

注意

如果您在 Google Colab 中執行,請確保切換到 GPU 執行時型別。有關更多資訊,請參閱 啟用 CUDA

嵌入式(Embeddings)#

在構建推薦系統時,類別特徵(categorical features)通常具有巨大的基數(cardinality),例如帖子、使用者、廣告等等。

為了表示這些實體和模擬這些關係,我們使用 嵌入式(embeddings)。在機器學習中,嵌入式是在高維空間中表示實數向量,用於表示單詞、影像或使用者等複雜資料中的含義

推薦系統中的嵌入式#

現在您可能會問,這些嵌入式是如何生成的?嗯,嵌入式表示為 嵌入表(Embedding Table) 中的單獨行,也稱為嵌入權重。之所以如此,是因為嵌入式或嵌入表權重與模型中的其他權重一樣,都透過梯度下降進行訓練!

嵌入表只是一個用於儲存嵌入式的大型矩陣,具有兩個維度(B, N),其中

  • B 是表中儲存的嵌入式數量

  • N 是每個嵌入式的維度數(N 維嵌入式)。

嵌入表的輸入代表嵌入查詢,用於檢索特定索引或行的嵌入式。在許多大型系統中使用的推薦系統中,唯一的 ID 不僅用於特定使用者,還跨越帖子和廣告等實體,用作相應嵌入表的查詢索引!

嵌入式在推薦系統中透過以下過程進行訓練

  • 輸入/查詢索引被作為唯一 ID 輸入模型。ID 會被雜湊到嵌入表的總大小,以防止出現 ID > 行數的問題。

  • 然後檢索嵌入式並進行 池化(pooling),例如取嵌入式的總和或平均值。這是必需的,因為每個示例的嵌入式數量可能不同,而模型需要一致的形狀。

  • 嵌入式與模型的其餘部分一起用於生成預測,例如廣告的 點選率 (CTR)

  • 根據預測和示例的標籤計算損失,並且 模型的所有權重都透過梯度下降和反向傳播進行更新,包括與該示例關聯的嵌入權重

這些嵌入式對於表示使用者、帖子和廣告等類別特徵至關重要,以便捕獲關係並做出好的推薦。 深度學習推薦模型 (DLRM) 論文更詳細地討論了在推薦系統中使嵌入表的技術細節。

本教程介紹了嵌入式的概念,展示了 TorchRec 特定的模組和資料型別,並說明了 TorchRec 的分散式訓練是如何工作的。

import torch

PyTorch 中的嵌入式#

在 PyTorch 中,我們有以下型別的嵌入式:

  • torch.nn.Embedding:一種嵌入表,其前向傳播返回嵌入式本身。

  • torch.nn.EmbeddingBag:嵌入表,其前向傳播返回然後被池化的嵌入式,例如總和或平均值,也稱為 池化嵌入式 (Pooled Embeddings)

在本節中,我們將簡要介紹透過將索引傳遞到表中來執行嵌入查詢。

num_embeddings, embedding_dim = 10, 4

# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)

# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

embeddings = embedding_collection(ids)

# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)

print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))

恭喜!您現在對如何使用嵌入表有了基本的瞭解——這是現代推薦系統的基礎之一!這些表代表實體及其關係。例如,給定使用者與他們喜歡過的頁面和帖子的關係。

TorchRec 功能概述#

在上面一節中,我們學習瞭如何使用嵌入表,這是現代推薦系統的基礎之一!這些表代表實體和關係,例如使用者、頁面、帖子等。鑑於這些實體不斷增加,通常會應用 雜湊 函式來確保 ID 在特定嵌入表的範圍內。但是,為了表示大量實體並減少雜湊衝突,這些表可能會變得非常龐大(想想廣告的數量)。事實上,這些表可能變得如此龐大,以至於即使有 80GB 記憶體也無法容納在單個 GPU 上。

為了訓練具有龐大嵌入表的模型,需要將這些表分片到 GPU 上,這會帶來全新的並行化和最佳化問題和機遇。幸運的是,我們有 TorchRec 庫 <https://docs.pytorch.com.tw/torchrec/overview.html>`__,它已經遇到、整合並解決了其中許多問題。TorchRec 是一個 提供大規模分散式嵌入式原語的庫

接下來,我們將探索 TorchRec 庫的主要功能。我們將從 torch.nn.Embedding 開始,然後擴充套件到自定義 TorchRec 模組,探索分散式訓練環境,為嵌入式生成分片計劃,檢視固有的 TorchRec 最佳化,並將模型擴充套件為可在 C++ 中進行推理。下面是本節內容的快速大綱

  • TorchRec 模組和資料型別

  • 分散式訓練、分片和最佳化

讓我們開始匯入 TorchRec

import torchrec

本節將介紹 TorchRec 模組和資料型別,包括 EmbeddingCollectionEmbeddingBagCollectionJaggedTensorKeyedJaggedTensorKeyedTensor 等實體。

EmbeddingBagEmbeddingBagCollection#

我們已經瞭解了 torch.nn.Embeddingtorch.nn.EmbeddingBag。TorchRec 透過建立嵌入式集合來擴充套件這些模組,換句話說,就是可以擁有多個嵌入表的模組,即 EmbeddingCollectionEmbeddingBagCollection。我們將使用 EmbeddingBagCollection 來表示一組嵌入式包。

在下面的示例程式碼中,我們建立了一個 EmbeddingBagCollection (EBC),其中包含兩個嵌入式包,一個代表 產品,一個代表 使用者。每個表,product_tableuser_table,都由一個 64 維、大小為 4096 的嵌入式表示。

ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)
print(ebc.embedding_bags)

讓我們檢查 EmbeddingBagCollection 的前向方法以及模組的輸入和輸出。

import inspect

# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))

TorchRec 輸入/輸出資料型別#

TorchRec 為其模組的輸入和輸出提供了不同的資料型別:JaggedTensorKeyedJaggedTensorKeyedTensor。現在您可能會問,為什麼要建立新的資料型別來表示稀疏特徵?要回答這個問題,我們必須瞭解稀疏特徵在程式碼中是如何表示的。

稀疏特徵也稱為 id_list_featureid_score_list_feature,它們是將被用作嵌入表索引以檢索該 ID 的嵌入式的 ID。舉一個非常簡單的例子,想象一下一個稀疏特徵是使用者與之互動過的廣告。輸入本身將是一組使用者與之互動過的廣告 ID,檢索到的嵌入式將是對這些廣告的語義表示。在程式碼中表示這些特徵的棘手之處在於,在每個輸入示例中,ID 的數量是可變的。有一天,使用者可能只與一個廣告互動,而第二天他們可能與三個廣告互動。

下面展示了一個簡單的表示,其中有一個 lengths 張量,表示一個批次中的每個示例有多少索引,以及一個包含索引本身的 values 張量。

# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])

# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])

接下來,讓我們看看偏移量以及每個批次包含的內容。

# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
    "Second Batch: ",
    id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)

from torchrec import JaggedTensor

# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)

# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())

# Convert to list of values
print("List of Values: ", jt.to_dense())

# ``__str__`` representation
print(jt)

from torchrec import KeyedJaggedTensor

# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!

product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))

# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())

# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())

# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())

# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())

# ``KeyedJaggedTensor`` string representation
print(kjt)

# Q2: What are the offsets for the ``KeyedJaggedTensor``?

# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result

# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())

# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)

# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

恭喜!您現在已經瞭解了 TorchRec 模組和資料型別。為您一路走到這裡給自己鼓掌。接下來,我們將學習分散式訓練和分片。

分散式訓練和分片#

現在我們對 TorchRec 模組和資料型別有了初步瞭解,是時候更進一步了。

請記住,TorchRec 的主要目的是提供分散式嵌入式的原語。到目前為止,我們只在單個裝置上處理了嵌入表。這之所以可行,是因為嵌入表的大小一直很小,但在生產環境中通常不是這樣。嵌入表通常會變得非常龐大,一個表無法容納在單個 GPU 上,這就需要多個裝置和分散式環境。

在本節中,我們將透過 TorchRec 探索設定分散式環境、實際生產訓練的進行方式以及嵌入表的分片。

本節也只使用 1 個 GPU,但它將以分散式方式進行處理。這僅是訓練的限制,因為訓練每個 GPU 都有一個程序。推理不會遇到此要求。

在下面的示例程式碼中,我們設定了 PyTorch 分散式環境。

警告

如果您在 Google Colab 中執行,則只能呼叫此單元格一次,再次呼叫將導致錯誤,因為您只能初始化一次程序組。

import os

import torch.distributed as dist

# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")

print(f"Distributed environment initialized: {dist}")

分散式嵌入式#

我們已經接觸過主要的 TorchRec 模組:EmbeddingBagCollection。我們已經研究了它是如何工作的以及資料在 TorchRec 中是如何表示的。然而,我們尚未探索 TorchRec 的一個主要部分,那就是 分散式嵌入式

GPU 是目前最受歡迎的機器學習工作負載選擇,因為它們能夠處理比 CPU 大幾個數量級的浮點運算/秒(FLOPs)。然而,GPU 的缺點是快速記憶體(HBM,類似於 CPU 的 RAM)有限,通常只有幾十 GB。

推薦系統模型可能包含遠超單個 GPU 記憶體限制的嵌入表,因此需要將嵌入表分佈到多個 GPU 上,也稱為 模型並行。另一方面,資料並行 是指在每個 GPU 上覆制整個模型,每個 GPU 負責處理不同的資料批次進行訓練,並在反向傳播時同步梯度。

計算量需求較低但記憶體需求較高(嵌入式)的模型部分採用模型並行進行分發,而 計算量需求較高但記憶體需求較低(密集層、MLP 等)的模型部分採用資料並行進行分發

分片(Sharding)#

為了分佈嵌入表,我們將嵌入表分成幾部分並將其放置在不同的裝置上,這也稱為“分片”。

有很多方法可以分片嵌入表。最常見的方法是

  • 按表分片 (Table-Wise): 表完全放置在一個裝置上

  • 按列分片 (Column-Wise): 嵌入表的列被分片

  • 按行分片 (Row-Wise): 嵌入表的行被分片

分片模組(Sharded Modules)#

雖然所有這些聽起來都很複雜和難以實現,但您很幸運。TorchRec 提供了所有用於輕鬆分散式訓練和推理的原語!事實上,TorchRec 模組有兩個對應的類,用於在分散式環境中處理任何 TorchRec 模組。

  • 模組分片器 (module sharder): 這個類公開了一個 shard API,它負責分片 TorchRec 模組,生成一個分片模組。* 對於 EmbeddingBagCollection,分片器是 `EmbeddingBagCollectionSharder `

  • 分片模組 (Sharded module): 這個類是 TorchRec 模組的分片變體。它與常規 TorchRec 模組具有相同的輸入/輸出,但經過最佳化,可以在分散式環境中工作。* 對於 EmbeddingBagCollection,分片變體是 ShardedEmbeddingBagCollection

每個 TorchRec 模組都有一個未分片版本和一個分片版本。

  • 未分片版本用於原型設計和實驗。

  • 分片版本用於在分散式環境中進行分散式訓練和推理。

TorchRec 模組的分片版本,例如 EmbeddingBagCollection,將處理模型並行所需的一切,例如 GPU 之間的通訊,以將嵌入式分發到正確的 GPU。

我們 EmbeddingBagCollection 模組的回顧

ebc

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()

# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"

print(f"Process Group: {pg}")

規劃器(Planner)#

在我們展示分片如何工作之前,我們必須瞭解 規劃器,它有助於我們確定最佳分片配置。

給定嵌入表的數量和 GPU 的數量,存在許多不同的分片配置。例如,給定 2 個嵌入表和 2 個 GPU,您可以:

  • 將 1 個表放在每個 GPU 上

  • 將兩個表都放在一個 GPU 上,另一個 GPU 上不放任何表

  • 將某些行和列放在每個 GPU 上

考慮到所有這些可能性,我們通常需要一個對效能最優的分片配置。

這就是規劃器發揮作用的地方。規劃器能夠確定給定嵌入表的數量和 GPU 的數量,什麼是最佳配置。事實證明,手動完成這項工作極其困難,工程師需要考慮許多因素才能確保最佳的分片計劃。幸運的是,TorchRec 在使用規劃器時提供了自動規劃器。

TorchRec 規劃器

  • 評估硬體的記憶體限制

  • 根據記憶體獲取(如嵌入查詢)估算計算量

  • 處理特定於資料 的因素

  • 考慮頻寬等其他硬體特定資訊,以生成模型的最佳分片計劃

為了考慮所有這些變數,TorchRec 規劃器可以接受 各種關於嵌入表、約束、硬體資訊和拓撲的資料,以幫助生成模型的最佳分片計劃,該計劃通常在整個堆疊中提供。

要了解更多關於分片的資訊,請參閱我們的 分片教程

# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)

print(f"Sharding Plan generated: {plan}")

規劃器結果#

如上所示,執行規劃器時會產生大量輸出。我們可以看到許多正在計算的統計資料以及我們的表最終放置的位置。

執行規劃器的結果是一個靜態計劃,可以用於分片!這允許分片對於生產模型是靜態的,而不是每次都確定新的分片計劃。下面,我們使用分片計劃最終生成我們的 ShardedEmbeddingBagCollection

# The static plan that was generated
plan

env = ShardingEnv.from_process_group(pg)

# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC Module: {sharded_ebc}")

使用 LazyAwaitable 進行 GPU 訓練#

請記住,TorchRec 是一個高度最佳化的分散式嵌入式庫。TorchRec 引入的一個用於提高 GPU 訓練效能的概念是 LazyAwaitable `。您將看到 LazyAwaitable 型別作為各種分片 TorchRec 模組的輸出。 LazyAwaitable 型別所做的就是儘可能延遲計算某個結果,它透過充當非同步型別來做到這一點。

from typing import List

from torchrec.distributed.types import LazyAwaitable


# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size

    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)


awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)

kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))

print(kt.keys())

print(kt.values().shape)

# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

分片 TorchRec 模組的結構#

我們現在已經成功地基於生成的分片計劃對 EmbeddingBagCollection 進行了分片!分片模組具有 TorchRec 的通用 API,這些 API 抽象了多個 GPU 之間的分散式通訊/計算。事實上,這些 API 針對訓練和推理的效能進行了高度最佳化。以下是 TorchRec 提供的用於分散式訓練/推理的三個通用 API

  • input_dist: 處理將輸入從 GPU 分發到 GPU。

  • lookups: 使用 FBGEMM TBE 以最佳化、批次的方式執行實際的嵌入查詢(稍後會詳細介紹)。

  • output_dist: 處理將輸出從 GPU 分發到 GPU。

輸入和輸出的分發是透過 NCCL Collectives,即 All-to-All 來完成的,這是所有 GPU 相互之間傳送和接收資料的地方。TorchRec 與 PyTorch 分散式介面進行通訊,併為終端使用者提供簡潔的抽象,消除了對底層細節的擔憂。

反向傳播執行所有這些集合操作,但順序相反,用於分發梯度。input_distlookupoutput_dist 都依賴於分片方案。由於我們是按表分片,因此這些 API 是由 TwPooledEmbeddingSharding 構建的模組。

sharded_ebc

# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists

最佳化嵌入查詢#

在執行一組嵌入表的查詢時,一個簡單的解決方案是迭代所有 nn.EmbeddingBags 併為每個表執行一次查詢。這正是標準、未分片的 EmbeddingBagCollection 所做的。但是,雖然這個解決方案很簡單,但速度非常慢。

FBGEMM 是一個提供 GPU 運算子(也稱為核心)的庫,這些運算子經過高度最佳化。其中一個運算子稱為 Table Batched Embedding (TBE),它提供了兩個主要的最佳化:

  • 表批處理(Table batching),允許您使用一個核心呼叫來查詢多個嵌入式。

  • 最佳化器融合(Optimizer Fusion),允許模組根據規範的 PyTorch 最佳化器和引數自行更新。

ShardedEmbeddingBagCollection 使用 FBGEMM TBE 進行查詢,而不是傳統的 nn.EmbeddingBags,以實現最佳化的嵌入查詢。

sharded_ebc._lookups

DistributedModelParallel#

我們現在已經完成了對單個 EmbeddingBagCollection 的分片!我們能夠使用 EmbeddingBagCollectionSharder 並使用未分片的 EmbeddingBagCollection 來生成 ShardedEmbeddingBagCollection 模組。這個工作流程是可以的,但在實現模型並行時,通常使用 DistributedModelParallel (DMP) 作為標準介面。當使用 DMP 包裝模型(在本例中為 ebc)時,將發生以下情況:

  1. 決定如何分片模型。DMP 將收集可用的分片器,並制定一個關於如何最佳分片嵌入表(例如,EmbeddingBagCollection)的計劃。

  2. 實際分片模型。這包括在適當的裝置上為每個嵌入表分配記憶體。

DMP 接收我們剛才嘗試過的所有內容,例如靜態分片計劃、分片器列表等。然而,它也有一些不錯的預設設定,可以無縫分片 TorchRec 模型。在這個玩具示例中,由於我們有兩個嵌入表和一個 GPU,TorchRec 將兩者都放在單個 GPU 上。

ebc

model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))

out = model(kjt)
out.wait()

model


from fbgemm_gpu.split_embedding_configs import EmbOptimType

分片最佳實踐#

目前,我們的配置只在 1 個 GPU(或 rank)上進行分片,這很簡單:只需將所有表放在 1 個 GPU 的記憶體中。然而,在真實的生產用例中,嵌入表通常在數百個 GPU 上進行分片,採用不同的分片方法,如按表、按行和按列分片。確定一個合適的分片配置(以防止記憶體不足問題)同時保持記憶體和計算的平衡以獲得最佳效能,這是極其重要的。

新增最佳化器#

請記住,TorchRec 模組針對大規模分散式訓練進行了高度最佳化。一個重要的最佳化與最佳化器有關。

TorchRec 模組提供了一個無縫的 API,用於融合反向傳播和訓練中的最佳化步驟,從而顯著提高效能並減少記憶體使用,同時還可以對分配給不同模型引數的最佳化器進行粒度控制。

最佳化器類#

TorchRec 使用 CombinedOptimizer,它包含一組 KeyedOptimizersCombinedOptimizer 有效地簡化了處理模型中各種子組的多個最佳化器的操作。 KeyedOptimizer 擴充套件了 torch.optim.Optimizer,並透過引數字典進行初始化,該字典暴露了引數。 EmbeddingBagCollection 中的每個 TBE 模組都有自己的 KeyedOptimizer,它們組合成一個 CombinedOptimizer

TorchRec 中的融合最佳化器#

使用 DistributedModelParallel 時,最佳化器是融合的,這意味著最佳化器更新在反向傳播中完成。這是 TorchRec 和 FBGEMM 中的一項最佳化,其中最佳化器嵌入式梯度不會被具體化並直接應用於引數。這帶來了顯著的記憶體節省,因為嵌入式梯度通常與引數本身的大小相同。

但是,您可以選擇將最佳化器設定為 dense,這樣就不會應用此最佳化,允許您檢查嵌入式梯度或根據需要對其應用計算。在這種情況下,密集最佳化器將是您的 規範的 PyTorch 模型訓練迴圈與最佳化器

一旦透過 DistributedModelParallel 建立了最佳化器,您仍然需要管理不與 TorchRec 嵌入式模組關聯的其他引數的最佳化器。要找到其他引數,請使用 in_backward_optimizer_filter(model.named_parameters())。像處理普通 Torch 最佳化器一樣為這些引數應用最佳化器,並將此與 model.fused_optimizer 組合成一個 CombinedOptimizer,您可以在訓練迴圈中使用它來執行 zero_gradstep

EmbeddingBagCollection 新增最佳化器#

我們將透過兩種方式執行此操作,這兩種方式是等效的,但根據您的偏好提供選項:

  1. 透過分片器中的 fused_params 傳遞最佳化器關鍵字引數。

  2. 透過 apply_optimizer_in_backward,它將最佳化器引數轉換為 fused_params 以傳遞給 EmbeddingBagCollectionEmbeddingCollection 中的 TBE

# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter


# We initialize the sharder with
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}

# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)

# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(
    ebc, plan.plan[""], env, torch.device("cuda")
)

# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(
    f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}"
)

print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

import copy

from torch.distributed.optim import (
    _apply_optimizer_in_backward as apply_optimizer_in_backward,
)

# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it

# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}

for name, param in ebc_apply_opt.named_parameters():
    print(f"{name=}")
    apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)

sharded_ebc_apply_opt = sharder.shard(
    ebc_apply_opt, plan.plan[""], env, torch.device("cuda")
)

# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))

# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(
    dict(
        in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())
    ).keys()
)

# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")

loss.backward()

ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")

結論#

在本教程中,您已經完成了分散式推薦系統模型的訓練。如果您對推理感興趣,TorchRec 倉庫 包含一個關於如何在推理模式下執行 TorchRec 的完整示例。

有關更多資訊,請參閱我們的 dlrm 示例,其中包括使用 用於個性化和推薦系統的深度學習推薦模型 中描述的方法在 Criteo 1TB 資料集上進行多節點訓練。