注意
轉到末尾 下載完整的示例程式碼。
(原型)使用 GPUDirect Storage 加速 torch.save 和 torch.load#
GPUDirect Storage 實現了 GPU 記憶體和儲存之間的直接記憶體訪問傳輸的直接資料路徑,避免了透過 CPU 的短暫緩衝。
在 **2.7** 版本中,我們引入了 torch.cuda.gds 的新原型 API,它們是 cuFile API 的輕量級封裝,可與 torch.Tensor 一起使用,以提高 I/O 效能。
在本教程中,我們將演示如何在本地檔案系統上使用 torch.cuda.gds API 結合 torch.save 和 torch.load 生成的檢查點。
瞭解如何在本地檔案系統上使用
torch.cuda.gdsAPI 結合torch.save和torch.load生成的檢查點
PyTorch v.2.7.0 或更高版本
必須根據 文件 安裝 GPUDirect Storage
確保您正在儲存/載入到的檔案系統支援 GPUDirect Storage。
將 GPUDirect Storage 與 torch.save 和 torch.load 結合使用#
GPUDirect Storage 需要 4KB 的儲存對齊。您可以使用 torch.utils.serialization.config.save.storage_alignment 來切換此設定。
import torch
from torch.utils.serialization import config as serialization_config
serialization_config.save.storage_alignment = 4096
- 該過程涉及的步驟如下:
寫入檢查點檔案,而不寫入任何實際資料。這會在磁碟上預留空間。
使用
FakeTensor讀取檢查點中與每個張量關聯的儲存的偏移量。使用
GDSFile在這些偏移量處寫入相應的資料。
給定一個位於 GPU 上的張量狀態字典,可以使用 torch.serialization.skip_data 上下文管理器來儲存一個檢查點,該檢查點包含除儲存位元組以外的所有相關元資料。對於狀態字典中的每個 torch.Storage,將在檢查點內為儲存位元組預留空間。
import torch.nn as nn
m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data():
torch.save(sd, "checkpoint.pt")
我們可以透過在 FakeTensorMode 下載入來獲取每個儲存在檢查點內應寫入的偏移量。FakeTensor 是一個具有張量元資料(如大小、步幅、dtype、裝置)但沒有儲存位元組的張量。以下程式碼片段不會具體化任何資料,但會將每個 FakeTensor 標記為在檢查點內與該張量對應的偏移量。
如果您在訓練過程中持續儲存相同的狀態字典,則只需獲取一次偏移量,並且可以重複使用相同的偏移量。同樣,如果一個張量將被重複儲存或載入,您可以使用 torch.cuda.gds.gds_register_buffer,它封裝了 cuFileBufRegister 以將儲存註冊為 GDS 緩衝區。
請注意,torch.cuda.gds.GdsFile.save_storage 繫結到同步 cuFileWrite API,因此之後不需要進行同步。
import os
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")
for k, v in fake_sd.items():
print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")
f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)
for k, v in sd.items():
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# save_storage is a wrapper around `cuFileWrite`
f.save_storage(v.untyped_storage(), offset)
我們透過 torch.load 並進行比較來驗證儲存的檢查點的正確性。
sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
載入流程是反向的:您可以使用 torch.load 和 torch.serialization.skip_data 上下文管理器來載入除儲存位元組外的所有內容。這意味著檢查點中的任何張量都將被建立,但它們的儲存將是空的(就像張量是透過 torch.empty 建立的一樣)。
with torch.serialization.skip_data():
sd_loaded = torch.load("checkpoint.pt")
我們再次使用 FakeTensorMode 來獲取檢查點偏移量,並確定載入的檢查點與儲存的檢查點相同。
與 torch.cuda.gds.GdsFile.save_storage 類似,torch.cuda.gds.GdsFile.load_storage 繫結到同步 cuFileRead API,因此之後不需要進行同步。
for k, v in sd_loaded.items():
assert not torch.equal(v, sd[k])
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# load_storage is a wrapper around `cuFileRead`
f.load_storage(v.untyped_storage(), offset)
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
del f
結論#
在本教程中,我們演示瞭如何在本地檔案系統上使用原型 torch.cuda.gds API 結合 torch.save 和 torch.load。如果您有任何反饋,請在 PyTorch GitHub 倉庫中提交一個 issue。