評價此頁

torch.distributed.tensor#

創建於:2025年6月13日 | 最後更新於:2025年8月23日

注意

torch.distributed.tensor 目前處於 Alpha 階段,正在開發中。我們承諾大部分文件中列出的 API 向後相容,但如果需要,可能會進行 API 更改。

PyTorch DTensor (分散式張量)#

PyTorch DTensor 提供了簡單靈活的張量分片原語,可透明地處理分散式邏輯,包括跨裝置/主機的分片儲存、運算元計算和集合通訊。在處理多維分片時,DTensor 可用於構建不同的並行解決方案並支援分片 state_dict 表示。

請參閱基於 DTensor 構建的 PyTorch 原生並行解決方案的示例。

DTensor 遵循 SPMD (單程式多資料) 程式設計模型,使使用者能夠像編寫 **單裝置程式一樣編寫分散式程式,並具有相同的收斂特性**。它透過指定 DeviceMeshPlacement 來提供統一的張量分片佈局 (DTensor Layout)。

  • DeviceMesh 使用 n 維陣列表示叢集的裝置拓撲和通訊器。

  • Placement 描述了邏輯張量在 DeviceMesh 上的分片佈局。DTensor 支援三種類型的 placement:ShardReplicatePartial

DTensor 類 API#

DTensortorch.Tensor 的子類。這意味著一旦建立了 DTensor,它就可以像 torch.Tensor 一樣使用,包括執行各種 PyTorch 運算元,就像在單裝置上執行一樣,從而為 PyTorch 運算元提供適當的分散式計算。

除了現有的 torch.Tensor 方法外,它還提供了一組額外的方法來與 torch.Tensor 互動,對 DTensor 進行 redistribute(重新分佈)佈局,獲取所有裝置上的完整張量內容等。

class torch.distributed.tensor.DTensor(local_tensor, spec, *, requires_grad)#

DTensor (分散式張量) 是 torch.Tensor 的子類,它提供了類似單裝置的抽象,用於對多裝置 torch.Tensor 進行程式設計。它透過 DeviceMesh 和以下型別的 Placement 來描述分散式張量的分片佈局 (DTensor Layout)。

  • Shard:張量在 DeviceMeshdim 維度上,按照張量的 dim 維度進行分片。

  • Replicate:張量在 DeviceMesh 維度上的裝置上進行復制。

  • Partial:張量在 DeviceMesh 維度上的裝置上等待歸約。

呼叫 PyTorch 運算元時,DTensor 會重寫 PyTorch 運算元,以便在必要時執行分片計算併發出通訊。除了運算元計算,DTensor 還會根據運算元語義本身正確地轉換或傳播 placement(DTensor Layout),並生成新的 DTensor 輸出。

為了確保呼叫 PyTorch 運算元時 DTensor 分片計算的數值正確性,DTensor 要求運算元的每個張量引數都必須是 DTensor。

注意

直接使用 Tensor 子類建構函式不是建立 DTensor 的推薦方式(即它不能正確處理 autograd,因此不是公共 API)。請參閱 create_dtensor 部分,瞭解如何建立 DTensor

返回型別

DTensor

__create_chunk_list__()[source]#

返回一個 ChunkStorageMetadata 列表,這是一個描述當前 rank 上本地分片/副本大小/偏移量的 dataclass。對於 DTensor,每個 rank 將擁有一個本地分片/副本,因此返回的列表通常只有一個元素。

這個 dunder 方法主要用於分散式 checkpoint。

返回

一個 List[ChunkStorageMetadata] 物件,表示當前 rank 上的分片大小/偏移量。

static from_local(local_tensor, device_mesh=None, placements=None, *, run_check=False, shape=None, stride=None)[source]#

根據指定的 device_meshplacements,從每個 rank 上的本地 torch.Tensor 建立一個 DTensor

引數
  • local_tensor (torch.Tensor) – 每個 rank 上的本地 torch.Tensor。

  • device_mesh (DeviceMesh, optional) – 用於放置張量的 DeviceMesh,如果未指定,則必須在 DeviceMesh 上下文管理器下呼叫,預設值:None

  • placements (List[Placement], optional) – 描述如何將本地 torch.Tensor 放置在 DeviceMesh 上的 placement,必須與 device_mesh.ndim 的元素數量相同。

關鍵字引數
  • run_check (bool, optional) – 以額外的通訊為代價,跨 rank 執行健全性檢查,以檢查每個本地張量的元資訊以確保正確性。如果在 placements 中有 Replicate,則 device mesh 維度的第一個 rank 上的資料將被廣播到其他 rank。預設值:False

  • shape (torch.Size, optional) – 一個指定 DTensor 大小的整數列表,該 DTensor 構建在 local_tensor 之上。注意,如果 local_tensor 的大小在 rank 之間不同,則需要提供此引數。如果未提供,則假設給定的分散式張量在 rank 之間均勻分片,從而計算出 shape。預設值:None

  • stride (tuple, optional) – 一個指定 DTensor stride 的整數列表。如果未提供,則假設給定的分散式張量在 rank 之間均勻分片,從而計算出 stride。預設值:None

返回

一個 DTensor 物件。

返回型別

DTensor

注意

run_check=False 時,使用者有責任確保傳入的本地張量在 rank 之間是正確的(即,對於 Shard(dim) placement,張量是分片的;對於 Replicate() placement,張量是複製的)。否則,建立的 DTensor 的行為是未定義的。

注意

from_local 是可微分的,建立的 DTensor 物件上的 requires_grad 將取決於 local_tensor 是否 requires_grad。

full_tensor(*, grad_placements=None)[source]#

返回此 DTensor 的完整張量。它將執行必要的集合操作,以收集其 DeviceMesh 上的其他 rank 的本地張量並將它們連線起來。這是以下程式碼的語法糖:

dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()

關鍵字引數

grad_placements (List[Placement], optional) – 描述從此函式返回的完整張量的任何梯度佈局的未來佈局的 placement。 full_tensor 將 DTensor 轉換為一個完整的 torch.Tensor,並且返回的 torch.tensor 在程式碼的後續部分可能無法用作原始複製的 DTensor 佈局。此引數是使用者可以提供給 autograd 的提示,以防返回張量的梯度佈局與原始複製的 DTensor 佈局不匹配。如果未指定,我們將假設完整張量的梯度佈局是複製的。

返回

一個表示此 DTensor 完整張量的 torch.Tensor 物件。

返回型別

張量

注意

full_tensor 是可微分的。

redistribute(device_mesh=None, placements=None, *, async_op=False, forward_dtype=None, backward_dtype=None)[source]#

redistribute 執行必要的集體操作,將當前 DTensor 從其當前 placement 重新分佈到新的 placement,或從其當前 DeviceMesh 重新分佈到新的 DeviceMesh。也就是說,透過為 DeviceMesh 的每個維度指定 Replicate placement,我們可以將 Sharded DTensor 轉換為 Replicated DTensor。

當在單個裝置 mesh 維度上從當前 placement 重新分佈到新的 placement 時,我們將執行以下操作,包括通訊集體或本地操作:

  1. Shard(dim) -> Replicate(): all_gather

  2. Shard(src_dim) -> Shard(dst_dim): all_to_all

  3. Replicate() -> Shard(dim): 本地分塊(即 torch.chunk

  4. Partial() -> Replicate(): all_reduce

  5. Partial() -> Shard(dim): reduce_scatter

redistribute 會正確地為在 1D 或 N-D DeviceMesh 上建立的 DTensor 確定必要的重新分佈步驟。

引數
  • device_mesh (DeviceMesh, optional) – 用於放置 DTensor 的 DeviceMesh。如果未指定,則使用當前 DTensor 的 DeviceMesh。預設值:None

  • placements (List[Placement], optional) – 描述如何將 DTensor 放置到 DeviceMesh 中的新 placement,必須與 device_mesh.ndim 的元素數量相同。預設值:在所有 mesh 維度上覆制。

關鍵字引數
  • async_op (bool, optional) – 是否非同步執行 DTensor 重新分佈操作。預設值:False

  • forward_dtype (torch.dtype, optional) – 在其 forward 傳播中重新分佈本地張量之前,可以將本地張量的資料型別轉換為 forward_dtype。結果 DTensor 的資料型別將為 forward_dtype。預設值:None。

  • backward_dtype (torch.dtype, optional) – 在其 backward 傳播中重新分佈本地張量之前,可以將本地張量的資料型別轉換為 backward_dtype。結果 DTensor 的梯度將被轉換回當前 DTensor 的資料型別。預設值:None

返回

一個 DTensor 物件。

返回型別

DTensor

注意

redistribute 是可微分的,這意味著使用者無需擔心 redistribute 操作的 backward 公式。

注意

redistribute 目前僅支援在相同的 DeviceMesh 上重新分佈 DTensor。如果您需要將 DTensor 重新分佈到不同的 DeviceMesh,請提交 issue。

to_local(*, grad_placements=None)[source]#

獲取 DTensor 在其當前 rank 上的本地張量。對於分片,它返回邏輯張量檢視的本地分片;對於複製,它返回當前 rank 上的副本。

關鍵字引數

grad_placements (List[Placement], optional) – 描述從此函式返回的張量的任何梯度佈局的未來佈局的 placement。 to_local 將 DTensor 轉換為本地張量,並且返回的本地張量在程式碼的後續部分可能無法用作原始 DTensor 佈局。此引數是使用者可以提供給 autograd 的提示,以防返回張量的梯度佈局與原始 DTensor 佈局不匹配。如果未指定,我們將假設梯度佈局與原始 DTensor 保持不變,並將其用於梯度計算。

返回

一個 torch.TensorAsyncCollectiveTensor 物件。它表示當前 rank 上的本地張量。當返回 AsyncCollectiveTensor 物件時,表示本地張量尚未就緒(即通訊尚未完成)。在這種情況下,使用者需要呼叫 wait 來等待本地張量就緒。

返回型別

張量

注意

to_local 是可微分的,返回的本地張量的 requires_grad 將取決於 DTensor 是否 requires_grad。

property device_mesh: DeviceMesh#

與此 DTensor 物件關聯的 DeviceMesh 屬性。

注意

device_mesh 是隻讀屬性,無法設定。

property placements: tuple[torch.distributed.tensor.placement_types.Placement, ...]#

此 DTensor 的 placements 屬性,描述了該 DTensor 在其 DeviceMesh 上的佈局。

注意

placements 是隻讀屬性,無法設定。

DeviceMesh 作為分散式通訊器#

DeviceMesh 是從 DTensor 構建的,作為描述叢集裝置拓撲並表示多維通訊器(基於 ProcessGroup)的抽象。有關如何建立/使用 DeviceMesh 的詳細資訊,請參閱 DeviceMesh 教程

DTensor Placement 型別#

DTensor 支援以下型別的 Placement 在每個 DeviceMesh 維度上。

class torch.distributed.tensor.placement_types.Shard(dim)[source]#

Shard(dim) placement 描述了 DTensor 在對應的 DeviceMesh 維度上,沿著張量的 dim 維度進行分片,其中 DeviceMesh 維度上的每個 rank 只持有全域性張片的一個分片/片段。 Shard(dim) placement 遵循 torch.chunk(dim) 的語義,其中最後一個分片在 DeviceMesh 維度上可能為空,當張量維度在 DeviceMesh 維度上不能被整除時。 Shard placement 可以被所有 DTensor API 使用(例如,distribute_tensor、from_local 等)。

引數

dim (int) – 張量維度,描述 DTensor 沿著其對應的 DeviceMesh 維度進行分片。

警告

沿張量維度進行分片,而張量維度大小不能被 DeviceMesh 維度整除,目前處於實驗階段,可能會發生變化。

dim: int#
class torch.distributed.tensor.placement_types.Replicate[source]#

Replicate() placement 描述了 DTensor 在對應的 DeviceMesh 維度上進行復制,其中 DeviceMesh 維度上的每個 rank 都持有全域性張量的一個副本。 Replicate placement 可以被所有 DTensor API 使用(例如,distribute_tensorDTensor.from_local 等)。

class torch.distributed.tensor.placement_types.Partial(reduce_op='sum')[source]#

Partial(reduce_op) placement 描述了 DTensor 在指定的 DeviceMesh 維度上等待歸約,其中 DeviceMesh 維度上的每個 rank 持有全域性張量的部分值。使用者可以使用 redistributePartial DTensor 重新分佈到指定 DeviceMesh 維度上的 ReplicateShard(dim) placement,這將觸發底層必要的通訊操作(即 allreducereduce_scatter)。

引數

reduce_op (str, optional) – 用於部分 DTensor 生成 Replicated/Sharded DTensor 的歸約 op。僅支援逐元素歸約操作,包括:“sum”、“avg”、“product”、“max”、“min”,預設值:“sum”。

注意

Partial placement 可以作為 DTensor 運算元的結果生成,並且只能由 DTensor.from_local API 使用。

reduce_op: str = 'sum'#
class torch.distributed.tensor.placement_types.Placement[source]#

Placement 型別的基類,它描述了 DTensor 如何放置在 DeviceMesh 上。 PlacementDeviceMesh 一起可以描述 DTensor Layout。它是三種主要 DTensor Placement 型別:ShardReplicatePartial 的基類。

此類不應直接使用,主要作為型別存根。

is_partial(reduce_op=None)[source]#
返回型別

布林值

is_replicate()[source]#
返回型別

布林值

is_shard(dim=None)[source]#
返回型別

布林值

建立 DTensor 的不同方式#

有三種方法可以構造一個 DTensor
  • distribute_tensor() 從每個 rank 上的邏輯或“全域性” torch.Tensor 建立一個 DTensor。這可用於分片葉子 torch.Tensor(例如,模型引數/緩衝區和輸入)。

  • DTensor.from_local() 從每個 rank 上的本地 torch.Tensor 建立一個 DTensor,可用於從非葉子 torch.Tensor(例如,forward/backward 期間的中間啟用張量)建立 DTensor

  • DTensor 提供了專用的張量工廠函式(例如 empty()ones()randn() 等),允許透過直接指定 DeviceMeshPlacement 來進行不同的 DTensor 建立。與 distribute_tensor() 相比,這可以直接在裝置上實現分片記憶體,而不是在初始化邏輯張量記憶體後進行分片。

從邏輯 torch.Tensor 建立 DTensor#

torch.distributed 中的 SPMD (單程式多資料) 程式設計模型透過 (例如 torchrun) 啟動多個程序來執行相同的程式,這意味著程式中的模型將首先在不同的程序上初始化(即模型可能初始化在 CPU、meta device,或者直接在 GPU 上,如果記憶體足夠)。

DTensor 提供了一個 distribute_tensor() API,它可以分片模型權重或張量到 DTensor,從而使建立的 DTensor 符合單裝置語義,這對於 **數值正確性** 至關重要。

torch.distributed.tensor.distribute_tensor(tensor, device_mesh=None, placements=None, *, src_data_rank=0)[source]#

根據指定的 placements,將一個葉子 torch.Tensor(即 nn.Parameter/buffers)分發到 device_meshdevice_mesh 的 rank 和 placements 的數量必須相同。要分發的 tensor 是邏輯/全域性張量,API 將使用 DeviceMesh 第一個 rank 上的 tensor 作為事實來源以保留單裝置語義。如果您想在 Autograd 計算的中間構建 DTensor,請改用 DTensor.from_local()

引數
  • tensor (torch.Tensor) – 要分發的 torch.Tensor。請注意,如果您想在裝置 mesh 維度的裝置數量不能整除的維度上分片張量,我們將使用 torch.chunk 語義來分片張量並分散分片。不均勻分片行為是實驗性的,可能會發生變化。

  • device_mesh (DeviceMesh, optional) – 用於分發張量的 DeviceMesh,如果未指定,則必須在 DeviceMesh 上下文管理器下呼叫,預設值:None

  • placements (List[Placement], optional) – 描述如何將張量放置在 DeviceMesh 上的 placement,必須與 device_mesh.ndim 的元素數量相同。如果未指定,我們將預設將張量從 device_mesh 的每個維度的第一個 rank 複製到該 device_mesh

關鍵字引數

src_data_rank (int, optional) – 邏輯/全域性張量源資料的 rank,distribute_tensor() 使用它來將分片/副本分散/廣播到其他 rank。預設情況下,我們在每個 DeviceMesh 維度的 group_rank=0 作為源資料,以保留單裝置語義。如果顯式傳遞 Nonedistribute_tensor() 將直接使用其本地資料,而不是嘗試透過 scatter/broadcast 來保留單裝置語義。預設值:0

返回

一個每個 rank 上的 DTensorXLAShardedTensor 物件。

返回型別

DTensor

注意

當使用 xla device_type 初始化 DeviceMesh 時,distribute_tensor 返回 XLAShardedTensor。有關更多詳細資訊,請參閱 此 issue。XLA 整合處於實驗階段,可能會發生變化。

除了 distribute_tensor(),DTensor 還提供了一個 distribute_module() API,以便更容易地在 nn.Module 層面進行分片。

torch.distributed.tensor.distribute_module(module, device_mesh=None, partition_fn=None, input_fn=None, output_fn=None)[source]#

此函式公開三個函式來控制模組的引數/輸入/輸出。

1. 透過指定 partition_fn (例如,允許使用者根據指定的 partition_fn 將 Module 引數轉換為 DTensor 引數) 來在執行時執行之前對模組進行分片。2. 透過指定 input_fnoutput_fn 來控制模組在執行時期間的輸入或輸出。(例如,將輸入轉換為 DTensor,將輸出轉換回 torch.Tensor)。

引數
  • module (nn.Module) – 使用者要分割槽的模組。

  • device_mesh (DeviceMesh) – 用於放置模組的裝置 mesh。

  • partition_fn (Callable) – 用於分割槽引數的函式(例如,在 device_mesh 上分片某些引數)。如果未指定 partition_fn,則預設情況下我們將 module 的所有模組引數複製到 mesh 上。

  • input_fn (Callable) – 指定輸入分佈,例如,可以控制模組的輸入如何分片。input_fn 將作為模組的 forward_pre_hook (forward 前置鉤子) 安裝。

  • output_fn (Callable) – 指定輸出分佈,例如,可以控制輸出如何分片,或將其轉換回 torch.Tensor。output_fn 將作為模組的 forward_hook (forward 後置鉤子) 安裝。

返回

一個包含所有 DTensor s 引數/緩衝區的模組。

返回型別

模組

注意

當使用 xla device_type 初始化 DeviceMesh 時,distribute_module 返回帶有 PyTorch/XLA SPMD 註釋的引數的 nn.Module。有關更多詳細資訊,請參閱 此 issue。XLA 整合處於實驗階段,可能會發生變化。

DTensor 工廠函式#

DTensor 還提供了專用的張量工廠函式,允許使用類似 torch.Tensor 的工廠函式 API(例如 torch.ones, torch.empty, 等)直接建立 DTensor,此外還可以為建立的 DTensor 指定 DeviceMeshPlacement

torch.distributed.tensor.zeros(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]#

返回一個填充了標量值 0 的 DTensor

引數

size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:zeros(1,2,3..) 或 zeros([1,2,3..]) 或 zeros((1,2,3..))

關鍵字引數
  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 所需返回 DTensor 的佈局。預設值:torch.strided

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

torch.distributed.tensor.ones(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#

返回一個填充了標量值 1 的 DTensor,其形狀由可變引數 size 定義。

引數

size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

關鍵字引數
  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 所需返回 DTensor 的佈局。預設值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

torch.distributed.tensor.empty(*size, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#

返回一個填充了未初始化資料的 DTensor。張量的形狀由可變引數 size 定義。

引數

size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:empty(1,2,3..) 或 empty([1,2,3..]) 或 empty((1,2,3..))

關鍵字引數
  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。 layout (torch.layout, optional): 所需返回 DTensor 的佈局。預設值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

torch.distributed.tensor.full(size, fill_value, *, dtype=None, layout=torch.strided, requires_grad=False, device_mesh=None, placements=None)[source]#

根據 device_meshplacements,使用 fill_value 填充,形狀由引數 size 定義,返回一個 DTensor

引數
  • size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

  • fill_value (Scalar) – 用於填充輸出張量的值。

關鍵字引數
  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 所需返回 DTensor 的佈局。預設值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

torch.distributed.tensor.rand(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[source]#

返回一個 DTensor,其中填充了 [0, 1) 區間內均勻分佈的隨機數。張量的形狀由可變引數 size 定義。

引數

size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

關鍵字引數
  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 所需返回 DTensor 的佈局。預設值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

torch.distributed.tensor.randn(*size, requires_grad=False, dtype=None, layout=torch.strided, device_mesh=None, placements=None)[原始碼]#

返回一個 DTensor,其中填充了均值為 0、方差為 1 的正態分佈隨機數。張量的形狀由可變引數 size 定義。

引數

size (int...) – 定義輸出 DTensor 形狀的整數序列。可以是可變數量的引數或列表或元組等集合。例如:ones(1,2,3..) 或 ones([1,2,3..]) 或 ones((1,2,3..))

關鍵字引數
  • dtype (torch.dtype, optional) – 所需返回 DTensor 的資料型別。預設值:如果 None,則使用全域性預設值(請參閱 torch.set_default_dtype())。

  • layout (torch.layout, optional) – 所需返回 DTensor 的佈局。預設值:torch.strided

  • requires_grad (bool, optional) – 如果 autograd 應該記錄返回的 DTensor 上的操作。預設值:False

  • device_meshDeviceMesh 型別,包含 rank 的 mesh 資訊。

  • placements – 一個 Placement 型別的序列:ShardReplicate

返回

每個 rank 上的一個 DTensor 物件。

返回型別

DTensor

隨機操作#

DTensor 提供了分散式 RNG 功能,以確保分片張量上的隨機操作獲得唯一值,並且複製張量上的隨機操作獲得相同值。該系統要求所有參與的 rank(例如 SPMD rank)在執行每個 dtensor 隨機操作之前都使用相同的生成器狀態開始,如果這是真的,它確保在每個 dtensor 隨機操作完成後它們都處於相同的狀態。隨機操作期間不執行通訊來同步 RNG 狀態。

接受 generator 關鍵字引數的操作將利用使用者傳入的生成器,如果傳入了,否則使用裝置上的預設生成器。無論使用哪個生成器,在 DTensor 操作之後它都會被推進。將同一個生成器用於 DTensor 和非 DTensor 操作是有效的,但必須小心確保非 DTensor 操作在所有 rank 上平等地推進生成器狀態。

當結合使用 DTensor 和流水線並行時,每個流水線階段的 rank 應使用不同的種子,而流水線階段內的 rank 應使用相同的種子。

DTensor 的 RNG 基礎架構基於 philox 演算法,並支援任何基於 philox 的後端(cuda 和其他類似 cuda 的裝置),但不幸的是,尚不支援 CPU 後端。

除錯#

日誌記錄#

啟動程式時,可以使用 TORCH_LOGS 環境變數從 torch._logging 啟用額外的日誌記錄。

  • TORCH_LOGS=+dtensor 將顯示 logging.DEBUG 訊息及其以上所有級別。

  • TORCH_LOGS=dtensor 將顯示 logging.INFO 訊息及其以上。

  • TORCH_LOGS=-dtensor 將顯示 logging.WARNING 訊息及其以上。

除錯工具#

要除錯應用了 DTensor 的程式,並瞭解底層發生的通訊的更多細節,DTensor 提供了一個 CommDebugMode

class torch.distributed.tensor.debug.CommDebugMode#

CommDebugMode 是一個上下文管理器,用於計算其上下文中的功能性通訊次數。它透過 TorchDispatchMode 實現此目的。

注意

並非所有通訊都已支援。

使用示例

mod = ...
comm_mode = CommDebugMode()
with comm_mode:
    mod.sum().backward()
print(comm_mode.get_comm_counts())
generate_comm_debug_tracing_table(noise_level=3)[原始碼]#

生成詳細表格,顯示模組級別的操作和通訊跟蹤資訊。資訊的數量取決於 noise_level

  1. 列印模組級別的通訊計數。

  2. 列印未包含在平凡操作中的 dTensor 操作,以及模組資訊。

  3. 列印未包含在平凡操作中的操作。

  4. 列印所有操作。

generate_json_dump(file_name='comm_mode_log.json', noise_level=3)[原始碼]#

建立用於構建瀏覽器視覺化的 json 檔案。0. 列印模組級別的通訊計數;1. 列印未包含在平凡操作中的 dTensor 操作;2. 列印未包含在平凡操作中的操作;3. 列印所有操作。

get_comm_counts()[原始碼]#

以字典形式返回通訊計數。

返回

通訊計數以字典形式返回。

返回型別

Dict[Any, int]

get_parameter_info()[原始碼]#
返回型別

dict[str, dict[str, Any]]

get_sharding_info()[原始碼]#
返回型別

dict[str, dict[str, Any]]

get_total_counts()[原始碼]#
返回型別

int

log_comm_debug_tracing_table_to_file(file_name='comm_mode_log.txt', noise_level=3)[原始碼]#

與控制檯 CommDebugMode 輸出的替代方法,寫入使用者指定的檔案的內容。

為了視覺化少於 3 個維度的 DTensor 的分片,DTensor 提供了 visualize_sharding()

torch.distributed.tensor.debug.visualize_sharding(dtensor, header='', use_rich=False)[原始碼]#

在終端中視覺化 1D 或 2D DTensor 的分片。

注意

這需要 tabulate 包,或者 richmatplotlib。對於空張量,將不列印任何分片資訊。

實驗性功能#

DTensor 還提供了一系列實驗性功能。這些功能處於原型階段,或者基本功能已完成但正在徵求使用者反饋。如果您對這些功能有反饋,請在 PyTorch 上提交一個 issue。

torch.distributed.tensor.experimental.context_parallel(mesh, *, buffers=None, buffer_seq_dims=None, no_restore_buffers=None)[原始碼]#

context_parallel 是一個實驗性 API,用於啟用上下文並行 (CP)。該 API 執行兩項操作:1) 使用支援 CP 的 SDPA(torch.nn.functional.scaled_dot_product_attention)進行補丁,2) 沿序列維度分片 buffers,每個 rank 根據 mesh 保留相應的分片。

引數
  • mesh (DeviceMesh) – 用於上下文並行的裝置網格。

  • buffers (Optional[List[torch.Tensor]]) – 其使用依賴於序列維度的緩衝區。例如輸入批次、標籤和位置嵌入緩衝區。這些緩衝區必須沿序列維度分片以確保準確性。分片將就地進行,緩衝區內的形狀將在上下文中更改。緩衝區將在上下文完成後恢復。no_restore_buffers 可用於指定哪些緩衝區不需要恢復。注意 buffers 不應包含任何 nn.Parameter。

  • buffer_seq_dims (Optional[List[int]]) – buffers 的序列維度。

  • no_restore_buffers (Optional[Set[torch.Tensor]]) – 這些緩衝區集合中的緩衝區在上下文退出後不會被恢復。此集合必須是 buffers 的子集。如果退出上下文後不再使用這些緩衝區,可以將它們放入此列表中以避免額外的恢復時間。

返回型別

Generator[None, None, None]

警告

torch.distributed.tensor.experimental.context_parallel 是 PyTorch 中的一個原型功能。API 可能會發生變化。

torch.distributed.tensor.experimental.local_map(func=None, out_placements=None, in_placements=None, in_grad_placements=None, device_mesh=None, *, redistribute_inputs=False)[原始碼]#

local_map() 是一個實驗性 API,它允許使用者將 DTensor 傳遞給一個為應用於 torch.Tensor 編寫的函式。這是透過提取 DTensor 的區域性元件,呼叫函式,並根據 out_placements 將輸出包裝回 DTensor 來實現的。

引數
  • func (Callable) – 要應用於 DTensor s 的每個區域性分片的函式。

  • out_placements (Union[PlacementType, Tuple[PlacementType, …]]) – func 的展平輸出中 DTensor s 的期望放置。如果展平的 output 是單個值,則 out_placements 應為 PlacementType 型別。否則,如果展平的 output 有多個值,則 out_placements 應為 PlacementType 值的元組,與展平的 output 是一對一對映。此外,對於 Tensor 輸出,我們使用 PlacementType 作為其放置(一個 Tuple[Placement] 值)。對於非 Tensor 輸出,PlacementType 應為 None。請注意,唯一的例外是當沒有傳入 DTensor 引數時。在這種情況下,即使 out_placements 不為 None,結果函式也應忽略期望的放置,因為函式不是用 DTensor s 執行的。

  • in_placements (Tuple[PlacementType, …], optional) – func 的展平輸入中 DTensor s 的必需放置。如果指定了 in_placementslocal_map() 將檢查每個 DTensor 引數的放置是否與必需的放置相同。如果放置不相同且 redistribute_inputsFalse,則會引發異常。否則,如果 redistribute_inputsTrue,則引數將首先重新分發到必需的分片放置,然後才將其區域性張量傳遞給 func。唯一的例外是當必需的放置不為 None 且引數為 torch.Tensor 時。在這種情況下,將跳過放置檢查,並將引數直接傳遞給 func。如果 in_placementsNone,則不執行放置檢查。預設值:None

  • in_grad_placements (Tuple[PlacementType, …], optional) – 與展平輸入 DTensor 對應的 DTensor s 梯度的放置提示。此引數是使用者可以提供給 to_local() 的提示,以防區域性張量輸入的梯度佈局與其 DTensor 輸入佈局不匹配。如果未指定,我們將假定區域性張量輸入的梯度佈局與原始 DTensor 輸入保持相同,並使用該佈局進行梯度計算。預設值:None。

  • device_mesh (DeviceMesh, optional) – 輸出 DTensor s 放置在其上的裝置網格。如果未指定,將從第一個輸入 DTensor 的裝置網格推斷。預設值:None。

關鍵字引數

redistribute_inputs (bool, optional) – 布林值,指示當輸入 DTensor s 的放置與必需的輸入放置不同時,是否重新分片這些輸入 DTensor s。如果此值為 False 且某些 DTensor 輸入具有不同的放置,則會引發異常。預設值:False。

返回

一個 Callable,它將 func 應用於輸入 DTensor 的每個區域性分片,並返回一個從 func 的返回值構造的 DTensor

引發
  • AssertionError – 對於任何非 DTensor 輸出,我們要求其在 out_placements 中的相應輸出放置為 None。如果不是這種情況,將引發 AssertionError。

  • ValueError – 如果 redistribute_inputs=False 但輸入 DTensor 根據 in_placements 需要重新分發。

示例

>>> def mm_allreduce_forward(device_mesh, W, X):
>>>     partial_sum_tensor = torch.mm(W, X)
>>>     reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh)
>>>     return reduced_tensor
>>>
>>> W = torch.randn(12, 8, requires_grad=False)
>>> X = torch.randn(8, 16, requires_grad=False)
>>> Y = torch.mm(W, X)
>>> row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
>>> col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
>>>
>>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor conversion
>>> local_mm_allreduce_forward = local_map(
>>>     mm_allreduce_forward,
>>>     out_placements=[Replicate()],
>>>     in_placements=[col_wise, row_wise],
>>>     device_mesh=device_mesh,
>>> )
>>>
>>> W_dt = distribute_tensor(
...     W, device_mesh, (col_wise)
... )  # col-wisely sharded W tensor
>>> X_dt = distribute_tensor(
...     X, device_mesh, (row_wise)
... )  # row-wisely sharded X tensor
>>> Y_dt = local_mm_allreduce_forward(
...     device_mesh, W_dt, X_dt
... )  # apply local_mm_allreduce_forward to DTensors

注意

此 API 目前是實驗性的,可能會發生更改。

torch.distributed.tensor.experimental.register_sharding(op)[原始碼]#

register_sharding() 是一個實驗性 API,它允許使用者為運算子註冊分片策略,當張量輸入和輸出為 DTensor 時。當以下情況時,它可能很有用:(1) op 沒有預設的分片策略,例如當 op 是 DTensor 不支援的自定義運算子時;(2) 當用戶希望覆蓋現有運算子的預設分片策略時。

引數

op (Union[OpOverload, List[OpOverload]]) – 要註冊自定義分片函式的運算子或運算子列表。

返回

一個函式裝飾器,可用於包裝一個定義指定運算子 op 的分片策略的函式。定義的 P分片策略將註冊到 DTensor,如果 DTensor 已經實現了該運算子,則會覆蓋預設的分片策略。自定義分片函式接受與原始 op 相同的輸入(除了如果一個引數是 torch.Tensor,它將被 DTensor 內部使用的類似張量的物件替換)。該函式應返回一個 2 元組序列,每個元組指定可接受的輸出放置及其對應的輸入放置。

示例

>>> @register_sharding(aten._softmax.default)
>>> def custom_softmax_sharding(x, dim, half_to_float):
>>>     softmax_dim = dim if dim >= 0 else dim + x.ndim
>>>     acceptable_shardings = []
>>>
>>>     all_replicate = ([Replicate()], [Replicate(), None, None])
>>>     acceptable_shardings.append(all_replicate)
>>>
>>>     for sharding_dim in range(x.ndim):
>>>         if sharding_dim != softmax_dim:
>>>             all_sharded = (
>>>                 [Shard(sharding_dim)],
>>>                 [Shard(sharding_dim), None, None],
>>>             )
>>>             acceptable_shardings.append(all_sharded)
>>>
>>>     return acceptable_shardings

注意

此 API 目前是實驗性的,可能會發生更改。