評價此頁

分散式 RPC 框架入門#

創建於:2020年01月01日 | 最後更新:2025年09月03日 | 最後驗證:2024年11月05日

作者Shen Li

注意

editgithub 上檢視和編輯此教程。

先決條件

本教程使用兩個簡單的示例來演示如何使用 PyTorch v1.4 中首次作為實驗性功能引入的 torch.distributed.rpc 包構建分散式訓練。兩個示例的原始碼可以在 PyTorch 示例 中找到。

之前的教程,分散式資料並行入門使用 PyTorch 編寫分散式應用程式,描述了 DistributedDataParallel,它支援一種特定的訓練模式,即模型在多個程序中複製,每個程序處理輸入資料的一個分割槽。有時,您可能會遇到需要不同訓練模式的場景。例如

  1. 在強化學習中,從環境中獲取訓練資料可能相對昂貴,而模型本身可能相當小。在這種情況下,可以考慮生成多個並行執行的觀察者並共享一個單獨的代理。在這種情況下,代理會本地處理訓練,但應用程式仍然需要庫來在觀察者和訓練器之間傳送和接收資料。

  2. 您的模型可能太大,無法裝入單個機器的 GPU 中,因此需要一個庫來幫助將模型拆分到多臺機器上。或者,您可能正在實現一個 引數伺服器 訓練框架,其中模型引數和訓練器位於不同的機器上。

上面的這些場景都可以透過 torch.distributed.rpc 包來幫助解決。在場景 1 中,RPCRRef 允許將資料從一個工作節點發送到另一個工作節點,同時輕鬆引用遠端資料物件。在場景 2 中,分散式自動微分分散式最佳化器 使執行反向傳播和最佳化器步驟如同本地訓練一樣。在接下來的兩個部分中,我們將透過強化學習示例和語言模型示例演示 torch.distributed.rpc 的 API。請注意,本教程的目的不是構建最準確或最高效的模型來解決給定的問題,而是展示如何使用 torch.distributed.rpc 包來構建分散式訓練應用程式。

使用 RPC 和 RRef 進行分散式強化學習#

本節介紹使用 RPC 構建玩具分散式強化學習模型以解決 OpenAI Gym 中的 CartPole-v1 的步驟。策略程式碼大部分是從現有的單執行緒 示例 中借用的,如下所示。我們將跳過 Policy 設計的細節,而是專注於 RPC 的使用。

import torch.nn as nn
import torch.nn.functional as F

class Policy(nn.Module):

    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)

我們準備介紹觀察者。在此示例中,每個觀察者都會建立自己的環境,並等待代理的命令來執行一個回合。在每個回合中,一個觀察者最多迴圈 n_steps 次迭代,在每次迭代中,它使用 RPC 將其環境狀態傳遞給代理並獲得一個動作。然後,它將該動作應用於其環境,並從環境中獲得獎勵和下一個狀態。之後,觀察者使用另一個 RPC 將獎勵報告給代理。同樣,請注意,這顯然不是最高效的觀察者實現。例如,一個簡單的最佳化可以將當前狀態和上一個獎勵打包到一個 RPC 中以減少通訊開銷。但是,目標是演示 RPC API,而不是構建 CartPole 的最佳求解器。所以,在這個例子中,我們保持邏輯簡單和兩個步驟明確。

import argparse
import gym
import torch.distributed.rpc as rpc

parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser.add_argument('--world_size', default=2, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed  for reproducibility')
args = parser.parse_args()

class Observer:

    def __init__(self):
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)

    def run_episode(self, agent_rref):
        state, ep_reward = self.env.reset(), 0
        for _ in range(10000):
            # send the state to the agent to get an action
            action = agent_rref.rpc_sync().select_action(self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purpose
            agent_rref.rpc_sync().report_reward(self.id, reward)

            # finishes after the number of self.env._max_episode_steps
            if done:
                break

代理的程式碼稍微複雜一些,我們將將其分解成多個部分。在此示例中,代理同時充當訓練器和主控節點,它向多個分散式觀察者傳送命令以執行回合,並且它還在本地記錄所有動作和獎勵,這些將在每個回合後的訓練階段使用。下面的程式碼顯示了 Agent 建構函式,其中大部分是初始化各種元件。最後的迴圈遠端初始化其他工作節點上的觀察者,並在本地持有這些觀察者的 RRef。代理稍後將使用這些觀察者 RRef 傳送命令。應用程式無需擔心 RRef 的生命週期。每個 RRef 的所有者維護一個引用計數對映來跟蹤其生命週期,並保證只要有任何活動的 RRef 使用者,遠端資料物件就不會被刪除。有關詳細資訊,請參閱 RRef 設計文件

import gym
import numpy as np

import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical

class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []

接下來,代理向觀察者公開兩個 API,用於選擇動作和報告獎勵。這些函式僅在代理本地執行,但將透過 RPC 由觀察者觸發。

class Agent:
    ...
    def select_action(self, ob_id, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        self.rewards[ob_id].append(reward)

讓我們在代理上新增一個 run_episode 函式,該函式指示所有觀察者執行一個回合。在此函式中,它首先建立一個列表來收集非同步 RPC 的期貨,然後遍歷所有觀察者 RRef 以進行非同步 RPC。在這些 RPC 中,代理還將自身的 RRef 傳遞給觀察者,以便觀察者也可以呼叫代理上的函式。如上所示,每個觀察者將透過 RPC 回撥代理,這是巢狀 RPC。每個回合後,saved_log_probsrewards 將包含記錄的動作機率和獎勵。

class Agent:
    ...
    def run_episode(self):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().run_episode,
                    args=(self.agent_rref,)
                )
            )

        # wait until all obervers have finished this episode
        for fut in futs:
            fut.wait()

最後,在一個回合之後,代理需要訓練模型,這在下面的 finish_episode 函式中實現。此函式中沒有 RPC,並且大部分是從單執行緒 示例 中借用的。因此,我們跳過描述其內容。

class Agent:
    ...
    def finish_episode(self):
      # joins probs and rewards from different observers into lists
      R, probs, rewards = 0, [], []
      for ob_id in self.rewards:
          probs.extend(self.saved_log_probs[ob_id])
          rewards.extend(self.rewards[ob_id])

      # use the minimum observer reward to calculate the running reward
      min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
      self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

      # clear saved probs and rewards
      for ob_id in self.rewards:
          self.rewards[ob_id] = []
          self.saved_log_probs[ob_id] = []

      policy_loss, returns = [], []
      for r in rewards[::-1]:
          R = r + args.gamma * R
          returns.insert(0, R)
      returns = torch.tensor(returns)
      returns = (returns - returns.mean()) / (returns.std() + self.eps)
      for log_prob, R in zip(probs, returns):
          policy_loss.append(-log_prob * R)
      self.optimizer.zero_grad()
      policy_loss = torch.cat(policy_loss).sum()
      policy_loss.backward()
      self.optimizer.step()
      return min_reward

有了 PolicyObserverAgent 類,我們就可以啟動多個程序來進行分散式訓練了。在此示例中,所有程序都執行相同的 run_worker 函式,它們使用 rank 來區分其角色。Rank 0 始終是代理,所有其他 rank 都是觀察者。代理透過反覆呼叫 run_episodefinish_episode 直到執行獎勵超過環境指定的獎勵閾值來充當主控節點。所有觀察者被動地等待來自代理的命令。程式碼由 rpc.init_rpcrpc.shutdown 包裝,它們分別初始化和終止 RPC 例項。更多詳細資訊可在 API 頁面 中找到。

import os
from itertools import count

import torch.multiprocessing as mp

AGENT_NAME = "agent"
OBSERVER_NAME="obs{}"

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # rank0 is the agent
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size)
        print(f"This will run until reward threshold of {agent.reward_threshold}"
                " is reached. Ctrl+C to exit.")
        for i_episode in count(1):
            agent.run_episode()
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
                    f"{agent.running_reward:.2f}")
            if agent.running_reward > agent.reward_threshold:
                print(f"Solved! Running reward is now {agent.running_reward}!")
                break
    else:
        # other ranks are the observer
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively waiting for instructions from the agent

    # block until all rpcs finish, and shutdown the RPC instance
    rpc.shutdown()


mp.spawn(
    run_worker,
    args=(args.world_size, ),
    nprocs=args.world_size,
    join=True
)

以下是使用 world_size=2 進行訓練時的一些示例輸出。

This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
Episode 10      Last reward: 26.00      Average reward: 10.01
Episode 20      Last reward: 16.00      Average reward: 11.27
Episode 30      Last reward: 49.00      Average reward: 18.62
Episode 40      Last reward: 45.00      Average reward: 26.09
Episode 50      Last reward: 44.00      Average reward: 30.03
Episode 60      Last reward: 111.00     Average reward: 42.23
Episode 70      Last reward: 131.00     Average reward: 70.11
Episode 80      Last reward: 87.00      Average reward: 76.51
Episode 90      Last reward: 86.00      Average reward: 95.93
Episode 100     Last reward: 13.00      Average reward: 123.93
Episode 110     Last reward: 33.00      Average reward: 91.39
Episode 120     Last reward: 73.00      Average reward: 76.38
Episode 130     Last reward: 137.00     Average reward: 88.08
Episode 140     Last reward: 89.00      Average reward: 104.96
Episode 150     Last reward: 97.00      Average reward: 98.74
Episode 160     Last reward: 150.00     Average reward: 100.87
Episode 170     Last reward: 126.00     Average reward: 104.38
Episode 180     Last reward: 500.00     Average reward: 213.74
Episode 190     Last reward: 322.00     Average reward: 300.22
Episode 200     Last reward: 165.00     Average reward: 272.71
Episode 210     Last reward: 168.00     Average reward: 233.11
Episode 220     Last reward: 184.00     Average reward: 195.02
Episode 230     Last reward: 284.00     Average reward: 208.32
Episode 240     Last reward: 395.00     Average reward: 247.37
Episode 250     Last reward: 500.00     Average reward: 335.42
Episode 260     Last reward: 500.00     Average reward: 386.30
Episode 270     Last reward: 500.00     Average reward: 405.29
Episode 280     Last reward: 500.00     Average reward: 443.29
Episode 290     Last reward: 500.00     Average reward: 464.65
Solved! Running reward is now 475.3163778435275!

在此示例中,我們展示瞭如何使用 RPC 作為通訊工具來跨工作節點傳遞資料,以及如何使用 RRef 來引用遠端物件。確實,您可以直接在 ProcessGroup sendrecv API 之上構建整個結構,或者使用其他通訊/RPC 庫。但是,透過使用 torch.distributed.rpc,您可以獲得底層原生的支援和持續最佳化的效能。

接下來,我們將展示如何將 RPC 和 RRef 與分散式自動微分和分散式最佳化器結合起來,以執行分散式模型並行訓練。

使用分散式自動微分和分散式最佳化器進行分散式 RNN#

在本節中,我們使用一個 RNN 模型來展示如何使用 RPC API 構建分散式模型並行訓練。示例 RNN 模型非常小,可以輕鬆裝入單個 GPU,但我們仍將其層拆分到兩個不同的工作節點上以演示這個想法。開發者可以將類似的技術應用於跨多個裝置和機器分發更大的模型。

RNN 模型設計借鑑自 PyTorch 示例 倉庫中的詞語言模型,該模型包含三個主要元件:一個嵌入表、一個 LSTM 層和一個解碼器。下面的程式碼將嵌入表和解碼器包裝到子模組中,以便它們的建構函式可以傳遞給 RPC API。在 EmbeddingTable 子模組中,我們故意將 Embedding 層放在 GPU 上以涵蓋用例。在 v1.4 中,RPC 始終在目標工作節點上建立 CPU 張量引數或返回值。如果函式接受 GPU 張量,則需要顯式將其移動到正確的裝置。

class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu()


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))

有了上述子模組,我們現在就可以使用 RPC 將它們組合起來建立一個 RNN 模型。在下面的程式碼中,ps 代表一個引數伺服器,它託管嵌入表和解碼器的引數。建構函式使用 remote API 在引數伺服器上建立一個 EmbeddingTable 物件和一個 Decoder 物件,並在本地建立 LSTM 子模組。在前向傳播過程中,訓練器使用 EmbeddingTable RRef 來查詢遠端子模組,並透過 RPC 將輸入資料傳遞給 EmbeddingTable 並獲取查詢結果。然後,它將嵌入透過本地 LSTM 層執行,最後使用另一個 RPC 將輸出傳送到 Decoder 子模組。總的來說,要實現分散式模型並行訓練,開發者可以將模型劃分為子模組,呼叫 RPC 遠端建立子模組例項,並在需要時使用 RRef 來查詢它們。正如您在下面的程式碼中看到的,它與單機模型並行訓練非常相似。主要區別在於將 Tensor.to(device) 替換為 RPC 函式。

class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the rremote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden

在介紹分散式最佳化器之前,讓我們新增一個輔助函式來生成模型引數 RRef 列表,該列表將被分散式最佳化器消耗。在本地訓練中,應用程式可以呼叫 Module.parameters() 來獲取所有引數張量的引用,並將其傳遞給本地最佳化器以進行後續更新。但是,相同的 API 在分散式訓練場景中不起作用,因為一些引數位於遠端機器上。因此,分散式最佳化器不是採用張量列表,而是採用 RRef 列表,每個模型引數(本地和遠端)都有一個 RRef。輔助函式非常簡單,只需為每個引數呼叫 Module.parameters() 並建立一個本地 RRef

def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs

然後,由於 RNNModel 包含三個子模組,我們需要呼叫 _parameter_rrefs 三次,並將其包裝到另一個輔助函式中。

class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params

現在,我們可以實現訓練迴圈了。在初始化模型引數後,我們建立 RNNModelDistributedOptimizer。分散式最佳化器將採用 RRef 引數列表,查詢所有不同的所有者工作節點,並在每個所有者工作節點上使用給定的引數(即 lr=0.05)建立給定的本地最佳化器(例如 SGD,您也可以使用其他本地最佳化器)。

在訓練迴圈中,它首先建立一個分散式自動微分上下文,這將幫助分散式自動微分引擎查詢梯度和涉及的 RPC 傳送/接收函式。分散式自動微分引擎的設計細節可以在其 設計說明 中找到。然後,它像本地模型一樣啟動前向傳播,並執行分散式反向傳播。對於分散式反向傳播,您只需要指定一個根列表,在本例中,它是損失 Tensor。分散式自動微分引擎將自動遍歷分散式圖並正確寫入梯度。接下來,它執行分散式最佳化器的 step 函式,該函式將聯絡所有涉及的本地最佳化器來更新模型引數。與本地訓練相比,一個細微的區別是您不需要執行 zero_grad(),因為每個自動微分上下文都有專門的空間來儲存梯度,並且由於我們每個迭代建立一個上下文,因此來自不同迭代的梯度不會累積到同一組 Tensors 中。

def run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2

    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id:
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss])
                # run distributed optimizer
                opt.step(context_id)
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))

最後,讓我們新增一些粘合程式碼來啟動引數伺服器和訓練器程序。

def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server do nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__=="__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)