FullyShardedDataParallel#
創建於: 2022年02月02日 | 最後更新於: 2025年06月11日
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[source]#
一個用於在資料並行工作程序之間分片模組引數的包裝器。
這受到 Xu et al. 以及 DeepSpeed 的 ZeRO Stage 3 的啟發。FullyShardedDataParallel 通常縮寫為 FSDP。
示例
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
使用 FSDP 涉及包裝你的模組,然後在之後初始化你的最佳化器。這是必需的,因為 FSDP 會更改引數變數。
在設定 FSDP 時,你需要考慮目標 CUDA 裝置。如果裝置有 ID (
dev_id),你有三種選擇:將模組放置在該裝置上
使用
torch.cuda.set_device(dev_id)設定裝置將
dev_id傳遞給device_id建構函式引數。
這確保了 FSDP 例項的計算裝置是目標裝置。對於選項 1 和 3,FSDP 初始化始終在 GPU 上進行。對於選項 2,FSDP 初始化發生在模組的當前裝置上,該裝置可能是 CPU。
如果你正在使用
sync_module_states=True標誌,你需要確保模組在 GPU 上,或者使用device_id引數指定一個 FSDP 將在 FSDP 建構函式中移動模組到的 CUDA 裝置。這是必需的,因為sync_module_states=True需要 GPU 通訊。FSDP 還會負責將輸入張量移動到前向方法中,以便進行 GPU 計算,因此你無需手動將它們從 CPU 移動。
對於
use_orig_params=True,ShardingStrategy.SHARD_GRAD_OP會暴露未分片的引數,而不是前向計算後的分片引數,這與ShardingStrategy.FULL_SHARD不同。如果你想檢查梯度,可以使用帶有with_grads=True的summon_full_params方法。當
limit_all_gathers=True時,你可能會在 FSDP 前向計算之前看到 CPU 執行緒沒有發出任何核心的間隙。這是故意的,它顯示了速率限制器的作用。以這種方式同步 CPU 執行緒可以防止為後續的 all-gather 過度分配記憶體,並且實際上不會延遲 GPU 核心的執行。出於與 autograd 相關的原因,FSDP 會在前向和後向計算期間將受管理模組的引數替換為
torch.Tensor檢視。如果你的模組的前向計算依賴於儲存的引數引用而不是在每次迭代時重新獲取引用,那麼它將看不到 FSDP 新建立的檢視,autograd 也將無法正常工作。最後,當使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD並將分片程序組設定為節點內(intra-node),將複製程序組設定為節點間(inter-node)時,設定NCCL_CROSS_NIC=1可以幫助提高某些叢集設定中複製程序組上的 all-reduce 時間。限制
在使用 FSDP 時,有幾點限制需要注意:
在使用 CPU 解除安裝時,FSDP 目前不支援
no_sync()之外的梯度累積。這是因為 FSDP 使用新近減少的梯度,而不是與任何現有梯度累積,這可能導致不正確的結果。FSDP 不支援執行位於 FSDP 例項內的子模組的前向傳遞。這是因為子模組的引數將被分片,但子模組本身不是 FSDP 例項,因此其前向傳遞不會適當地 all-gather 完整的引數。
FSDP 由於其後向鉤註冊方式,無法與雙反向傳播(double backwards)一起工作。
FSDP 在凍結引數方面存在一些約束。對於
use_orig_params=False,每個 FSDP 例項必須管理全部凍結或全部非凍結的引數。對於use_orig_params=True,FSDP 支援混合凍結和非凍結引數,但建議避免這樣做,以防止高於預期的梯度記憶體使用。從 PyTorch 1.12 開始,FSDP 對共享引數的支援有限。如果你的用例需要增強的共享引數支援,請在該 issue 中發帖。
你應該避免在不使用
summon_full_params上下文的情況下修改前向和後向之間的引數,因為修改可能不會持久。
- 引數
module (nn.Module) – 這是需要用 FSDP 包裝的模組。
process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 這是模型分片的程序組,因此也是 FSDP 的 all-gather 和 reduce-scatter 集合通訊所使用的程序組。如果為
None,則 FSDP 使用預設程序組。對於混合分片策略,例如ShardingStrategy.HYBRID_SHARD,使用者可以傳入一個程序組元組,分別代表分片和複製的組。如果為None,則 FSDP 會為使用者構建程序組,用於節點內分片和節點間複製。(預設:None)sharding_strategy (Optional[ShardingStrategy]) – 這配置了分片策略,可能在記憶體節省和通訊開銷之間進行權衡。有關詳細資訊,請參閱
ShardingStrategy。(預設:FULL_SHARD)cpu_offload (Optional[CPUOffload]) – 這配置了 CPU 解除安裝。如果設定為
None,則不發生 CPU 解除安裝。有關詳細資訊,請參閱CPUOffload。(預設:None)auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –
這指定了一個策略,用於將 FSDP 應用於
module的子模組,這對於通訊和計算重疊是必需的,因此會影響效能。如果為None,則 FSDP 只應用於module,使用者應手動將 FSDP 應用於父模組(自底向上進行)。為了方便起見,此引數直接接受ModuleWrapPolicy,它允許使用者指定要包裝的模組類(例如,Transformer 塊)。否則,它應該是一個可呼叫物件,接受三個引數module: nn.Module,recurse: bool, 和nonwrapped_numel: int,並返回一個bool,指定是否應將 FSDP 應用於傳入的module(如果recurse=False),或者如果遍歷應該繼續到模組的子樹(如果recurse=True)。使用者可以為可呼叫物件新增其他引數。torch.distributed.fsdp.wrap.py中的size_based_auto_wrap_policy是一個示例可呼叫物件,當其子樹中的引數超過 1 億個 numel 時,它會將 FSDP 應用於模組。我們建議在應用 FSDP 後列印模型並根據需要進行調整。示例
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # Configure a custom `min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (Optional[BackwardPrefetch]) – 這配置了顯式的後向 all-gather 預取。如果為
None,則 FSDP 不進行後向預取,並且在後向傳遞中沒有通訊和計算重疊。有關詳細資訊,請參閱BackwardPrefetch。(預設:BACKWARD_PRE)mixed_precision (Optional[MixedPrecision]) – 這配置了 FSDP 的原生混合精度。如果設定為
None,則不使用混合精度。否則,可以設定引數、緩衝區和梯度減少的 dtype。有關詳細資訊,請參閱MixedPrecision。(預設:None)ignored_modules (Optional[Iterable[torch.nn.Module]]) – 模組本身及其子模組的引數和緩衝區將被此例項忽略。直接位於
ignored_modules中的任何模組都不應是FullyShardedDataParallel例項,並且任何已經是已構造的FullyShardedDataParallel例項的子模組都不會被忽略(如果它們巢狀在此例項下)。此引數可用於避免在為模組粒度使用auto_wrap_policy或當引數的分片不由 FSDP 管理時,忽略特定引數。(預設:None)param_init_fn (Optional[Callable[[nn.Module], None]]) –
一個
Callable[torch.nn.Module] -> None,指定如何將當前位於 meta 裝置上的模組初始化到實際裝置。從 v1.12 開始,FSDP 透過is_meta檢測具有 meta 裝置上的引數或緩衝區的模組,並根據是否指定了param_init_fn來應用它,或者呼叫nn.Module.reset_parameters()。在這兩種情況下,實現都應該*只*初始化模組的引數/緩衝區,而不是其子模組的。這是為了避免重新初始化。此外,FSDP 還支援透過 torchdistX(pytorch/torchdistX)的deferred_init()API 進行延遲初始化,其中延遲模組透過呼叫param_init_fn(如果指定)或 torchdistX 的預設materialize_module()來初始化。如果指定了param_init_fn,則它將應用於所有 meta 裝置上的模組,這意味著它可能需要根據模組型別進行條件判斷。FSDP 在引數展平(flattening)和分片(sharding)之前呼叫初始化函式。示例
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # E.g. initialize depending on the module type >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # current CUDA device >>> # With torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # Will initialize via deferred_init.materialize_module(). >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (Optional[Union[int, torch.device]]) – 一個
int或torch.device,指定 FSDP 初始化發生的 CUDA 裝置,包括必要的模組初始化和引數分片。如果module在 CPU 上,應指定此引數以提高初始化速度。如果設定了預設 CUDA 裝置(例如,透過torch.cuda.set_device),則使用者可以將其傳遞給此引數。(預設:None)sync_module_states (bool) – 如果為
True,則每個 FSDP 模組將從 rank 0 廣播模組引數和緩衝區,以確保它們在 ranks 之間複製(增加了建構函式的通訊開銷)。這有助於以記憶體高效的方式載入state_dict檢查點,例如透過load_state_dict。有關示例,請參閱FullStateDictConfig。(預設:False)forward_prefetch (bool) – 如果為
True,則 FSDP 會在當前前向計算之前*顯式*預取下一個前向 all-gather。這僅對 CPU 密集型工作負載有用,在這種情況下,更早地發出下一個 all-gather 可以改善重疊。這應該只用於靜態圖模型,因為預取遵循第一次迭代的執行順序。(預設:False)limit_all_gathers (bool) – 如果為
True,則 FSDP 會顯式同步 CPU 執行緒,以確保 GPU 記憶體使用僅限於*兩個*連續的 FSDP 例項(當前正在執行計算的例項和正在預取的下一個例項的 all-gather)。如果為False,則 FSDP 允許 CPU 執行緒在沒有任何額外同步的情況下發出 all-gathers。(預設:True)我們經常稱此功能為“速率限制器”。此標誌應僅在特定 CPU 密集型工作負載且記憶體壓力較低時設定為False,此時 CPU 執行緒可以積極發出所有核心,而不必擔心 GPU 記憶體使用。use_orig_params (bool) – 將此設定為
True會使 FSDP 使用module的原始引數。FSDP 透過nn.Module.named_parameters()向用戶公開這些原始引數,而不是 FSDP 內部的FlatParameter。這意味著最佳化器步進在原始引數上執行,從而實現每個原始引數的超引數。FSDP 保留原始引數變數,並在未分片和分片形式之間操縱它們的資料,其中它們始終是底層未分片或分片的FlatParameter的檢視。使用當前演算法,分片形式始終是 1D 的,丟失了原始張量結構。一個原始引數在給定 rank 上可能具有全部、部分或無資料。在無資料的情況下,其資料將類似於大小為 0 的空張量。使用者不應編寫依賴於給定原始引數在分片形式中存在資料的程式。True是使用torch.compile()所必需的。將此設定為False會透過nn.Module.named_parameters()向用戶公開 FSDP 的內部FlatParameter。(預設:False)ignored_states (Optional[Iterable[torch.nn.Parameter], Optional[Iterable[torch.nn.Module]]) – 被此 FSDP 例項忽略的引數或模組,這意味著引數不會被分片,其梯度也不會在 ranks 之間減少。此引數與現有的
ignored_modules引數統一,我們可能會很快棄用ignored_modules。為了向後相容,我們同時保留ignored_states和 ignored_modules`,但 FSDP 只允許其中一個被指定為非None。device_mesh (Optional[DeviceMesh]) – DeviceMesh 可以作為 process_group 的替代品。當傳入 device_mesh 時,FSDP 將使用底層程序組進行 all-gather 和 reduce-scatter 集合通訊。因此,這兩個引數需要互斥。對於混合分片策略,例如
ShardingStrategy.HYBRID_SHARD,使用者可以傳入一個 2D DeviceMesh 而不是一個程序組元組。對於 2D FSDP + TP,使用者需要傳入 device_mesh 而不是 process_group。有關更多 DeviceMesh 資訊,請訪問:https://pytorch.com.tw/tutorials/recipes/distributed_device_mesh.html
- apply(fn)[source]#
將
fn遞迴應用於每個子模組(由.children()返回)以及自身。典型用法包括初始化模型的引數(另請參閱 torch.nn.init)。
與
torch.nn.Module.apply相比,此版本在應用fn之前額外收集了完整引數。不應在另一個summon_full_params上下文內呼叫它。- 引數
fn (
Module-> None) – 要應用於每個子模組的函式- 返回
self
- 返回型別
- clip_grad_norm_(max_norm, norm_type=2.0)[source]#
裁剪所有引數的梯度範數。
範數在所有引數的梯度上計算,這些梯度被視為一個單一向量,並且梯度會就地修改。
- 引數
- 返回
引數的總範數(視為一個單一向量)。
- 返回型別
如果每個 FSDP 例項使用
NO_SHARD,即沒有梯度在 ranks 之間分片,那麼你可以直接使用torch.nn.utils.clip_grad_norm_()。如果至少有一個 FSDP 例項使用分片策略(即非
NO_SHARD),那麼你應該使用此方法而不是torch.nn.utils.clip_grad_norm_(),因為此方法處理了梯度在 ranks 之間分片的事實。返回的總範數將具有所有引數/梯度中“最大”的 dtype,根據 PyTorch 的型別提升語義定義。例如,如果*所有*引數/梯度都使用低精度 dtype,則返回的範數的 dtype 將是該低精度 dtype,但如果至少存在一個引數/梯度使用 FP32,則返回的範數的 dtype 將是 FP32。
警告
這需要在所有 ranks 上呼叫,因為它使用了集合通訊。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[source]#
展平(Flatten)分片的最佳化器狀態字典。
該 API 與
shard_full_optim_state_dict()類似。唯一的區別是輸入的sharded_optim_state_dict應該由sharded_optim_state_dict()返回。因此,將會有 all-gather 呼叫,在每個 rank 上收集ShardedTensor。- 引數
sharded_optim_state_dict (Dict[str, Any]) – 對應於未展平引數的最佳化器狀態字典,幷包含分片的最佳化器狀態。
model (torch.nn.Module) – 參見
shard_full_optim_state_dict()。optim (torch.optim.Optimizer) –
model引數的最佳化器。
- 返回
- 返回型別
- static fsdp_modules(module, root_only=False)[source]#
返回所有巢狀的 FSDP 例項。
這可能包括
module本身,並且只有在root_only=True時才包括 FSDP 根模組。- 引數
module (torch.nn.Module) – 根模組,它可能是一個
FSDP模組,也可能不是。root_only (bool) – 是否只返回 FSDP 根模組。(預設:
False)
- 返回
巢狀在輸入
module中的 FSDP 模組。- 返回型別
List[FullyShardedDataParallel]
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[source]#
返回完整的最佳化器狀態字典。
在 rank 0 上合併完整的最佳化器狀態,並將其作為
dict返回,遵循torch.optim.Optimizer.state_dict()的約定,即帶有鍵"state"和"param_groups"。包含在model中的FSDP模組中的展平引數被映射回其未展平的引數。這需要在所有 ranks 上呼叫,因為它使用了集合通訊。但是,如果
rank0_only=True,則狀態字典僅在 rank 0 上填充,而所有其他 ranks 返回一個空的dict。與
torch.optim.Optimizer.state_dict()不同,此方法使用完整的引數名稱作為鍵,而不是引數 ID。與
torch.optim.Optimizer.state_dict()一樣,最佳化器狀態字典中包含的張量不會被克隆,因此可能會出現別名陷阱。為了最佳實踐,請考慮立即儲存返回的最佳化器狀態字典,例如使用torch.save()。- 引數
model (torch.nn.Module) – 根模組(可能是一個
FullyShardedDataParallel例項,也可能不是),其引數已傳入最佳化器optim。optim (torch.optim.Optimizer) –
model引數的最佳化器。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳入最佳化器的引數,表示引數組列表或引數的可迭代物件;如果為
None,則此方法假定輸入為model.parameters()。此引數已棄用,不再需要傳遞它。(預設:None)rank0_only (bool) – 如果為
True,則僅在 rank 0 上儲存填充的dict;如果為False,則在所有 ranks 上儲存。(預設:True)group (dist.ProcessGroup) – 模型的程序組或
None,如果使用預設程序組。(預設:None)
- 返回
一個
dict,包含model的最佳化器狀態。最佳化器狀態的分片基於state_dict_type。- 返回型別
Dict[str, Any]
- static get_state_dict_type(module)[source]#
獲取
module的 FSDP 模組後代的所有state_dict_type以及相應的配置。目標模組不必是 FSDP 模組。
- 返回
一個
StateDictSettings,包含當前設定的state_dict_type和state_dict/optim_state_dict配置。- 引發
AssertionError` if the StateDictSettings for differen –
FSDP submodules differ. –
- 返回型別
- named_buffers(*args, **kwargs)[source]#
返回模組緩衝區的迭代器,生成緩衝區的名稱和緩衝區本身。
在
summon_full_params()上下文管理器內部,會攔截緩衝區名稱並移除 FSDP 特定的展平緩衝區字首的所有出現。- 返回型別
- named_parameters(*args, **kwargs)[source]#
返回模組引數的迭代器,生成引數的名稱和引數本身。
攔截引數名稱,並在
summon_full_params()上下文管理器內部,移除 FSDP 特定的展平引數字首的所有出現。- 返回型別
- no_sync()[source]#
停用 FSDP 例項之間的梯度同步。
在此上下文內,梯度將累積在模組變數中,稍後將在退出上下文後的第一個前向-後向傳遞中進行同步。這應該只在根 FSDP 例項上使用,並將遞迴應用於所有子 FSDP 例項。
注意
這可能會導致更高的記憶體使用,因為 FSDP 將累積完整的模型梯度(而不是梯度分片),直到最終同步。
注意
當與 CPU 解除安裝一起使用時,梯度在上下文管理器內部不會解除安裝到 CPU。相反,它們只會在最終同步後立即解除安裝。
- 返回型別
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[source]#
轉換分片模型的最佳化器狀態字典。
給定的狀態字典可以轉換為三種類型之一:1)完整的最佳化器狀態字典,2)分片最佳化器狀態字典,3)本地最佳化器狀態字典。
對於完整的最佳化器狀態字典,所有狀態都未展平且未分片。可以透過
state_dict_type()指定僅 rank0 和僅 CPU 來避免 OOM。對於分片最佳化器狀態字典,所有狀態都未展平但已分片。可以透過
state_dict_type()指定僅 CPU 來進一步節省記憶體。對於本地狀態字典,不進行轉換。但是,狀態將從 nn.Tensor 轉換為 ShardedTensor 以表示其分片性質(這尚未支援)。
示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 引數
model (torch.nn.Module) – 根模組(可能是一個
FullyShardedDataParallel例項,也可能不是),其引數已傳入最佳化器optim。optim (torch.optim.Optimizer) –
model引數的最佳化器。optim_state_dict (Dict[str, Any]) – 要轉換的目標最佳化器狀態字典。如果值為 None,則使用 optim.state_dict()。(預設:
None)group (dist.ProcessGroup) – 模型跨其引數分片的程序組,或者在使用預設程序組時為
None。(預設:None)
- 返回
一個
dict,包含model的最佳化器狀態。最佳化器狀態的分片基於state_dict_type。- 返回型別
Dict[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[source]#
將最佳化器狀態字典轉換為可以載入到與 FSDP 模型關聯的最佳化器中的格式。
給定一個透過
optim_state_dict()轉換的optim_state_dict,它被轉換為展平的最佳化器狀態字典,可以載入到optim中,而optim是model的最佳化器。model必須由 FullyShardedDataParallel 分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # Save a checkpoint >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # Load a checkpoint >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- 引數
model (torch.nn.Module) – 根模組(可能是一個
FullyShardedDataParallel例項,也可能不是),其引數已傳入最佳化器optim。optim (torch.optim.Optimizer) –
model引數的最佳化器。optim_state_dict (Dict[str, Any]) – 要載入的最佳化器狀態。
is_named_optimizer (bool) – 這個最佳化器是 NamedOptimizer 還是 KeyedOptimizer。僅當
optim是 TorchRec 的 KeyedOptimizer 或 torch.distributed 的 NamedOptimizer 時才設定為 True。load_directly (bool) – 如果設定為 True,此 API 還將在返回結果之前呼叫 optim.load_state_dict(result)。否則,使用者負責呼叫
optim.load_state_dict()。(預設:False)group (dist.ProcessGroup) – 模型跨其引數分片的程序組,或者在使用預設程序組時為
None。(預設:None)
- 返回型別
- register_comm_hook(state, hook)[source]#
註冊一個通訊鉤子。
這是一個增強功能,為使用者提供了一個靈活的鉤子,他們可以在其中指定 FSDP 如何聚合多個工作程序的梯度。此鉤子可用於實現幾種演算法,例如 GossipGrad 和梯度壓縮,這些演算法在與
FullyShardedDataParallel訓練時涉及不同的引數同步通訊策略。警告
FSDP 通訊鉤子應在執行初始前向傳遞之前註冊,並且只註冊一次。
- 引數
state (object) –
傳遞給鉤子,用於在訓練過程中維護任何狀態資訊。例如,包括梯度壓縮中的錯誤反饋,GossipGrad 中下一個要通訊的對等方等。它在每個工作程序上本地儲存,並由工作程序上的所有梯度張量共享。
hook (Callable) – 可呼叫物件,其簽名如下:1)
hook: Callable[torch.Tensor] -> None:此函式接受一個 Python 張量,它代表與此 FSDP 單元包裝的模型(未被其他 FSDP 子單元包裝的)對應的所有變數的完整、展平、未分片梯度。然後執行所有必要的處理並返回None;2)hook: Callable[torch.Tensor, torch.Tensor] -> None:此函式接受兩個 Python 張量,第一個代表與此 FSDP 單元包裝的模型(未被其他 FSDP 子單元包裝的)對應的所有變數的完整、展平、未分片梯度。第二個代表一個預先調整大小的張量,用於儲存減少後的分片梯度的塊。在這兩種情況下,可呼叫物件都會執行所有必要的處理並返回None。簽名 1 的可呼叫物件預計會處理 NO_SHARD 情況下的梯度通訊。簽名 2 的可呼叫物件預計會處理分片情況下的梯度通訊。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[source]#
將最佳化器狀態字典
optim_state_dict重新鍵入以使用optim_state_key_type。這可用於實現具有 FSDP 例項的模型和不具有 FSDP 例項的模型之間最佳化器狀態字典的相容性。
要將 FSDP 完整最佳化器狀態字典(即來自
full_optim_state_dict())重新鍵入為使用引數 ID 並使其可載入到非包裝模型中>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
要將來自非包裝模型的普通最佳化器狀態字典重新鍵入為可載入到包裝模型中
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- 返回
使用
optim_state_key_type指定的引數鍵重新鍵入的最佳化器狀態字典。- 返回型別
Dict[str, Any]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[source]#
將完整的最佳化器狀態字典從 rank 0 散佈(scatter)到所有其他 ranks。
在每個 rank 上返回分片最佳化器狀態字典。返回值與
shard_full_optim_state_dict()相同,並且在 rank 0 上,第一個引數應該是full_optim_state_dict()的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 >>> # Define new model with possibly different world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
可以使用
shard_full_optim_state_dict()和scatter_full_optim_state_dict()來獲取要載入的分片最佳化器狀態字典。假設完整的最佳化器狀態字典存在於 CPU 記憶體中,前者要求每個 rank 在 CPU 記憶體中都有完整的字典,每個 rank 單獨對字典進行分片而無需任何通訊;而後者僅要求 rank 0 在 CPU 記憶體中擁有完整的字典,rank 0 將每個分片移動到 GPU 記憶體(用於 NCCL)並將其適當地通訊給 ranks。因此,前者具有更高的聚合 CPU 記憶體成本,而後者具有更高的通訊成本。- 引數
full_optim_state_dict (Optional[Dict[str, Any]]) – 對應於未展平引數的最佳化器狀態字典,並在 rank 0 上包含完整的非分片最佳化器狀態;引數在非零 ranks 上被忽略。
model (torch.nn.Module) – 根模組(可能是一個
FullyShardedDataParallel例項,也可能不是),其引數對應於full_optim_state_dict中的最佳化器狀態。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳入最佳化器的引數,表示引數組列表或引數的可迭代物件;如果為
None,則此方法假定輸入為model.parameters()。此引數已棄用,不再需要傳遞它。(預設:None)optim (Optional[torch.optim.Optimizer]) – 將載入此方法返回的狀態字典的最佳化器。這是比
optim_input優先使用的引數。(預設:None)group (dist.ProcessGroup) – 模型的程序組或
None,如果使用預設程序組。(預設:None)
- 返回
完整的最佳化器狀態字典現在已重新對映到展平的引數而不是未展平的引數,並且僅限於此 rank 的最佳化器狀態部分。
- 返回型別
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
設定目標模組所有後代 FSDP 模組的
state_dict_type。還接受模型和最佳化器狀態字典的可選配置。目標模組不必是 FSDP 模組。如果目標模組是 FSDP 模組,其
state_dict_type也將被更改。注意
此 API 應僅為頂層(根)模組呼叫。
注意
此 API 使能使用者透明地使用傳統的
state_dictAPI 來獲取模型檢查點,在根 FSDP 模組被另一個nn.Module包裝的情況下。例如,以下程式碼將確保對所有非 FSDP 例項呼叫state_dict,同時將 FSDP 例項分派到 sharded_state_dict 實現。示例
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- 引數
module (torch.nn.Module) – 根模組。
state_dict_type (StateDictType) – 要設定的
state_dict_type。state_dict_config (Optional[StateDictConfig]) – 目標
state_dict_type的配置。optim_state_dict_config (Optional[OptimStateDictConfig]) – 最佳化器狀態字典的配置。
- 返回
一個 StateDictSettings,包含模組的先前
state_dict型別和配置。- 返回型別
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[source]#
分片(Shard)一個完整的最佳化器狀態字典。
將
full_optim_state_dict中的狀態重新對映到展平的引數而不是未展平的引數,並將結果限制為該 rank 的最佳化器狀態部分。第一個引數應該是full_optim_state_dict()的返回值。示例
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # Define new model with possibly different world size >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
可以使用
shard_full_optim_state_dict()和scatter_full_optim_state_dict()來獲取要載入的分片最佳化器狀態字典。假設完整的最佳化器狀態字典存在於 CPU 記憶體中,前者要求每個 rank 在 CPU 記憶體中都有完整的字典,每個 rank 單獨對字典進行分片而無需任何通訊;而後者僅要求 rank 0 在 CPU 記憶體中擁有完整的字典,rank 0 將每個分片移動到 GPU 記憶體(用於 NCCL)並將其適當地通訊給 ranks。因此,前者具有更高的聚合 CPU 記憶體成本,而後者具有更高的通訊成本。- 引數
full_optim_state_dict (Dict[str, Any]) – 對應於未展平引數的最佳化器狀態字典,幷包含完整的非分片最佳化器狀態。
model (torch.nn.Module) – 根模組(可能是一個
FullyShardedDataParallel例項,也可能不是),其引數對應於full_optim_state_dict中的最佳化器狀態。optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 傳入最佳化器的引數,表示引數組列表或引數的可迭代物件;如果為
None,則此方法假定輸入為model.parameters()。此引數已棄用,不再需要傳遞它。(預設:None)optim (Optional[torch.optim.Optimizer]) – 將載入此方法返回的狀態字典的最佳化器。這是比
optim_input優先使用的引數。(預設:None)
- 返回
完整的最佳化器狀態字典現在已重新對映到展平的引數而不是未展平的引數,並且僅限於此 rank 的最佳化器狀態部分。
- 返回型別
Dict[str, Any]
- static sharded_optim_state_dict(model, optim, group=None)[source]#
以分片形式返回最佳化器狀態字典。
該 API 與
full_optim_state_dict()類似,但此 API 將所有非零維狀態分塊為ShardedTensor以節省記憶體。此 API 應僅在使用with state_dict_type(SHARDED_STATE_DICT):上下文管理器派生模型state_dict時使用。有關詳細用法,請參閱
full_optim_state_dict()。警告
返回的狀態字典包含
ShardedTensor,不能直接由常規optim.load_state_dict使用。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[source]#
設定目標模組所有後代 FSDP 模組的
state_dict_type。這個上下文管理器具有與
set_state_dict_type()相同的功能。有關詳細資訊,請閱讀set_state_dict_type()的文件。示例
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- 引數
module (torch.nn.Module) – 根模組。
state_dict_type (StateDictType) – 要設定的
state_dict_type。state_dict_config (Optional[StateDictConfig]) – 目標
state_dict_type的模型state_dict配置。optim_state_dict_config (Optional[OptimStateDictConfig]) – 目標
state_dict_type的最佳化器state_dict配置。
- 返回型別
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[source]#
使用此上下文管理器為 FSDP 例項暴露完整引數。
在模型的前向/後向計算*之後*可能有用,以獲取引數以進行額外的處理或檢查。它可以接受非 FSDP 模組,並根據
recurse引數,為所有包含的 FSDP 模組及其子模組呼叫完整引數。注意
這可以用於內部 FSDP。
注意
此上下文不能在(或從)前向或後向傳遞中使用。也不能在此上下文內啟動前向和後向。
注意
退出上下文管理器後,引數將恢復為它們的本地分片,儲存行為與前向傳遞相同。
注意
完整引數可以被修改,但只有對應於本地引數分片的部分在退出上下文管理器後才會保留(除非
writeback=False,在這種情況下,更改將被丟棄)。在 FSDP 不分片引數的情況下(例如,僅當world_size == 1或配置為NO_SHARD時),修改將保留,無論writeback如何。注意
此方法適用於本身不是 FSDP 但可能包含多個獨立 FSDP 單元的模組。在這種情況下,給定的引數將應用於所有包含的 FSDP 單元。
警告
請注意,
rank0_only=True與writeback=True結合使用目前不支援,並且會引發錯誤。這是因為在上下文內的模型引數形狀在 ranks 之間會有所不同,在上下文退出時寫入它們可能導致 ranks 之間的不一致。警告
請注意,
offload_to_cpu和rank0_only=False將導致完整引數被冗餘地複製到位於同一臺機器上的 GPU 的 CPU 記憶體中,這可能會帶來 CPU OOM 的風險。建議將offload_to_cpu與rank0_only=True一起使用。- 引數
recurse (布林值, 可選) – 遞迴呼叫所有引數以獲取巢狀的FSDP例項(預設值:True)。
writeback (布林值, 可選) – 如果為
False,則在退出上下文管理器後會丟棄對引數的修改;停用此項可以稍微提高效率(預設值:True)rank0_only (布林值, 可選) – 如果為
True,則僅在全域性rank 0上例項化完整的引數。這意味著在上下文中,只有rank 0擁有完整的引數,而其他rank則擁有分片(sharded)的引數。請注意,同時設定rank0_only=True和writeback=True是不支援的,因為在上下文內,模型引數在不同rank上的形狀會不同,並且寫入這些引數可能導致退出上下文時跨rank的不一致。offload_to_cpu (布林值, 可選) – 如果為
True,則將完整的引數解除安裝到 CPU。請注意,此解除安裝目前僅發生在引數被分片時(這僅在 world_size = 1 或NO_SHARD配置時不是這種情況)。建議將offload_to_cpu與rank0_only=True一起使用,以避免將模型引數冗餘地複製到相同的 CPU 記憶體中。with_grads (布林值, 可選) – 如果為
True,則梯度也會隨引數一起解分片(unsharded)。目前,這僅在向 FSDP 建構函式傳遞use_orig_params=True並且向此方法傳遞offload_to_cpu=False時才受支援。(預設值:False)
- 返回型別
- class torch.distributed.fsdp.BackwardPrefetch(value)[source]#
此配置用於顯式的向後預取(backward prefetching),透過在反向傳播中實現通訊和計算重疊來提高吞吐量,但會略微增加記憶體使用。
BACKWARD_PRE:這提供了最大的重疊,但記憶體使用也最多。它會在當前引數集計算梯度*之前*預取下一組引數。這會重疊*下一個 all-gather* 和*當前梯度計算*,在峰值時,它會在記憶體中同時儲存當前引數集、下一組引數集以及當前梯度集。BACKWARD_POST:這提供的重疊較少,但記憶體使用也較少。它會在當前引數集計算梯度*之後*預取下一組引數。這會重疊*當前 reduce-scatter* 和*下一個梯度計算*,並且在為下一組引數分配記憶體之前釋放當前引數集,在峰值時僅在記憶體中保留下一組引數和當前梯度集。FSDP 的
backward_prefetch引數可以接受None,這將完全停用向後預取。這沒有重疊,也不會增加記憶體使用。一般來說,我們不推薦這種設定,因為它可能會顯著降低吞吐量。
更多技術背景:對於使用 NCCL 後端的單個程序組,任何集體通訊(collective operations),即使是從不同的流發出的,都會爭用相同的每裝置 NCCL 流,這意味著集體通訊發出的相對順序對於重疊很重要。兩個向後預取值對應於不同的發出順序。
- class torch.distributed.fsdp.ShardingStrategy(value)[source]#
這指定了由
FullyShardedDataParallel使用的分散式訓練分片策略。FULL_SHARD:引數、梯度和最佳化器狀態都會被分片。對於引數,此策略在前向傳播之前進行解分片(透過 all-gather),在前向傳播之後進行重分片,在反向傳播計算之前進行解分片,並在反向傳播計算之後進行重分片。對於梯度,它在反向傳播計算之後進行同步和分片(透過 reduce-scatter)。分片後的最佳化器狀態在每個rank上本地更新。SHARD_GRAD_OP:梯度和最佳化器狀態在計算過程中被分片,此外,引數在計算之外也被分片。對於引數,此策略在前向傳播之前進行解分片,在前向傳播之後不進行重分片,僅在反向傳播計算之後進行重分片。分片後的最佳化器狀態在每個rank上本地更新。在no_sync()內部,引數在反向傳播計算之後不會被重分片。NO_SHARD:引數、梯度和最佳化器狀態都不會被分片,而是像 PyTorch 的DistributedDataParallelAPI 一樣在各個rank上進行復制。對於梯度,此策略在反向傳播計算之後透過 all-reduce 進行同步。未分片的最佳化器狀態在每個rank上本地更新。HYBRID_SHARD:在節點內應用FULL_SHARD,並在節點之間複製引數。這減少了通訊量,因為昂貴的 all-gather 和 reduce-scatter 操作僅在節點內完成,這對於中等大小的模型可能更具效能優勢。_HYBRID_SHARD_ZERO2:在節點內應用SHARD_GRAD_OP,並在節點之間複製引數。這類似於HYBRID_SHARD,但由於未分片的引數在前向傳播後不會被釋放,從而節省了反向傳播之前的 all-gather 操作,因此可能提供更高的吞吐量。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[source]#
此配置用於 FSDP 原生的混合精度訓練。
- 變數
param_dtype (可選[torch.dtype]) – 此引數指定模型引數在前向和後向傳播期間的資料型別,因此也指定了前向和後向計算的資料型別。在前向和後向傳播之外,*分片*的引數以全精度(例如,用於最佳化器步驟)保留,並且在模型檢查點(checkpointing)時,引數始終以全精度儲存。(預設值:
None)reduce_dtype (可選[torch.dtype]) – 此引數指定梯度約簡(即 reduce-scatter 或 all-reduce)的資料型別。如果此引數為
None但param_dtype不是None,則此引數將採用param_dtype的值,仍然以低精度執行梯度約簡。此引數可以與param_dtype不同,例如,強制梯度約簡以全精度執行。(預設值:None)buffer_dtype (可選[torch.dtype]) – 此引數指定緩衝區(buffers)的資料型別。FSDP 不分片緩衝區。相反,FSDP 在第一次前向傳遞時將其轉換為
buffer_dtype,並在之後一直保持該資料型別。對於模型檢查點,除了LOCAL_STATE_DICT外,緩衝區都以全精度儲存。(預設值:None)keep_low_precision_grads (布林值) – 如果為
False,則 FSDP 在反向傳播後將梯度向上轉換為全精度,以準備最佳化器步驟。如果為True,則 FSDP 將梯度保留在用於梯度約簡的資料型別中,如果使用支援低精度執行的自定義最佳化器,這可以節省記憶體。(預設值:False)cast_forward_inputs (布林值) – 如果為
True,則此 FSDP 模組將其前向引數和關鍵字引數轉換為param_dtype。這是為了確保引數和輸入的資料型別在前向計算時匹配,這是許多操作的要求。當僅將混合精度應用於部分但非全部 FSDP 模組時,可能需要將此設定為True,在這種情況下,一個混合精度 FSDP 子模組需要重新轉換其輸入。(預設值:False)cast_root_forward_inputs (布林值) – 如果為
True,則根 FSDP 模組將其前向引數和關鍵字引數轉換為param_dtype,這將覆蓋cast_forward_inputs的值。對於非根 FSDP 模組,此設定無效。(預設值:True)_module_classes_to_ignore (collections.abc.Sequence[type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]):此引數指定在使用
auto_wrap_policy時要忽略混合精度的模組類別:這些類別的模組將分別應用 FSDP,並停用混合精度(這意味著最終的 FSDP 構建將偏離指定的策略)。如果未指定auto_wrap_policy,則此引數無效。此 API 是實驗性的,可能會發生更改。(預設值:(_BatchNorm,))
注意
此 API 是實驗性的,可能會發生更改。
注意
只有浮點張量會被轉換為指定的資料型別。
注意
在
summon_full_params中,引數會被強制轉換為全精度,但緩衝區不會。注意
層歸一化(Layer norm)和批歸一化(Batch norm)即使在輸入資料型別為
float16或bfloat16等低精度時,也會以float32累積。停用這些歸一化模組的 FSDP 混合精度僅意味著仿射(affine)引數會保留為float32。然而,這會導致這些歸一化模組進行單獨的 all-gather 和 reduce-scatter 操作,這可能效率不高。因此,如果工作負載允許,使用者應優先考慮仍將混合精度應用於這些模組。注意
預設情況下,如果使用者傳遞的模型包含任何
_BatchNorm模組並指定了auto_wrap_policy,那麼批歸一化模組將單獨應用 FSDP,並停用混合精度。請參閱_module_classes_to_ignore引數。注意
MixedPrecision預設具有cast_root_forward_inputs=True和cast_forward_inputs=False。對於根 FSDP 例項,其cast_root_forward_inputs優先於其cast_forward_inputs。對於非根 FSDP 例項,其cast_root_forward_inputs值將被忽略。預設設定足以應對典型情況,即每個 FSDP 例項具有相同的MixedPrecision配置,並且在模型前向傳遞的開始處僅需要將輸入轉換為param_dtype。注意
對於具有不同
MixedPrecision配置的巢狀 FSDP 例項,我們建議將各個cast_forward_inputs值設定為True或False,以配置在每個例項的前向傳播之前是否進行輸入轉換。在這種情況下,由於轉換髮生在每個 FSDP 例項的前向傳播之前,父 FSDP 例項應在其非 FSDP 子模組之後執行其 FSDP 子模組,以避免由於不同的MixedPrecision配置而導致啟用資料型別發生更改。示例
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
上述示例是一個可行的示例。另一方面,如果將
model[1]替換為model[0],這意味著使用不同MixedPrecision的子模組首先執行其前向傳播,那麼model[1]將錯誤地看到float16啟用而不是bfloat16啟用。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[source]#
此配置用於 CPU 解除安裝。
- 變數
offload_params (布林值) – 此引數指定在不參與計算時是否將引數解除安裝到 CPU。如果為
True,則還會將梯度解除安裝到 CPU,這意味著最佳化器步驟將在 CPU 上執行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[source]#
StateDictConfig是所有state_dict配置類的基類。使用者應例項化一個子類(例如FullStateDictConfig),以便配置 FSDP 支援的相應state_dict型別的設定。- 變數
offload_to_cpu (布林值) – 如果為
True,則 FSDP 將 state dict 的值解除安裝到 CPU;如果為False,則 FSDP 將其保留在 GPU 上。(預設值:False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[source]#
FullStateDictConfig是一個配置類,用於StateDictType.FULL_STATE_DICT。我們建議在儲存完整的 state dict 時同時啟用offload_to_cpu=True和rank0_only=True,分別節省 GPU 記憶體和 CPU 記憶體。此類應透過state_dict_type()上下文管理器使用,如下所示:>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP >>> if dist.get_rank() == 0: >>> # Load checkpoint only on rank 0 to avoid memory redundancy >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument >>> # communicates loaded checkpoint states from rank 0 to rest of the world. >>> fsdp = FSDP( ... model, ... device_id=torch.cuda.current_device(), ... auto_wrap_policy=..., ... sync_module_states=True, ... ) >>> # After this point, all ranks have FSDP model with loaded checkpoint.
- 變數
rank0_only (布林值) – 如果為
True,則僅 rank 0 儲存完整的 state dict,非零 rank 儲存一個空字典。如果為False,則所有 rank 儲存完整的 state dict。(預設值:False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[source]#
ShardedStateDictConfig是一個配置類,用於StateDictType.SHARDED_STATE_DICT。- 變數
_use_dtensor (布林值) – 如果為
True,則 FSDP 將 state dict 的值儲存為DTensor;如果為False,則 FSDP 將其儲存為ShardedTensor。(預設值:False)
警告
_use_dtensor是ShardedStateDictConfig的一個私有欄位,它被 FSDP 用來確定 state dict 值的型別。使用者不應手動修改_use_dtensor。
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[source]#
OptimStateDictConfig是所有optim_state_dict配置類的基類。使用者應例項化一個子類(例如FullOptimStateDictConfig),以便配置 FSDP 支援的相應optim_state_dict型別的設定。- 變數
offload_to_cpu (布林值) – 如果為
True,則 FSDP 將 state dict 的張量值解除安裝到 CPU;如果為False,則 FSDP 將其保留在原始裝置上(除非啟用了引數 CPU 解除安裝,否則為 GPU)。(預設值:True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[source]#
- 變數
rank0_only (布林值) – 如果為
True,則僅 rank 0 儲存完整的 state dict,非零 rank 儲存一個空字典。如果為False,則所有 rank 儲存完整的 state dict。(預設值:False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[source]#
ShardedOptimStateDictConfig是一個配置類,用於StateDictType.SHARDED_STATE_DICT。- 變數
_use_dtensor (布林值) – 如果為
True,則 FSDP 將 state dict 的值儲存為DTensor;如果為False,則 FSDP 將其儲存為ShardedTensor。(預設值:False)
警告
_use_dtensor是ShardedOptimStateDictConfig的一個私有欄位,它被 FSDP 用來確定 state dict 值的型別。使用者不應手動修改_use_dtensor。
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[source]#