TensorDict 在分散式設定中¶
TensorDict 可以在分散式設定中使用,用於在節點之間傳遞張量。如果兩個節點可以訪問共享的物理儲存,則可以使用記憶體對映張量(memory-mapped tensor)在執行中的程序之間高效地傳遞資料。在此,我們提供一些關於如何在分散式 RPC 設定中實現這一點的資訊。有關分散式 RPC 的更多詳細資訊,請參閱 官方 pytorch 文件。
建立記憶體對映的 TensorDict¶
記憶體對映張量(和陣列)的一個巨大優點是它們可以儲存大量資料,並允許隨時訪問資料的切片,而無需將整個檔案讀入記憶體。TensorDict 在記憶體對映陣列和 torch.Tensor 類之間提供了一個名為 MemmapTensor 的介面。MemmapTensor 例項可以儲存在 TensorDict 物件中,從而允許 tensordict 表示儲存在磁碟上的大資料集,並可在節點之間以批處理的方式輕鬆訪問。
記憶體對映的 tensordict 可以透過以下方式建立:(1) 使用記憶體對映張量填充 TensorDict,或者 (2) 呼叫 tensordict.memmap_() 將其放入物理儲存。可以透過查詢 tensordict.is_memmap() 來輕鬆檢查 tensordict 是否已放入物理儲存。
建立記憶體對映張量本身有幾種方法。首先,可以簡單地建立一個空張量
>>> shape = torch.Size([3, 4, 5])
>>> tensor = Memmaptensor(*shape, prefix="/tmp")
>>> tensor[:2] = torch.randn(2, 4, 5)
prefix 屬性指示臨時檔案儲存的位置。至關重要的是,該張量必須儲存在每個節點都可以訪問的目錄中!
另一種選擇是將磁碟上的現有張量表示出來
>>> tensor = torch.randn(3)
>>> tensor = Memmaptensor(tensor, prefix="/tmp")
當張量很大或不適合放入記憶體時,將優先選擇前一種方法:它適用於非常大的張量,並作為節點之間的公共儲存。例如,可以建立一個數據集,以便單節點或不同節點都能輕鬆訪問,比載入每個檔案到記憶體中要快得多。
>>> dataset = TensorDict({
... "images": MemmapTensor(50000, 480, 480, 3),
... "masks": MemmapTensor(50000, 480, 480, 3, dtype=torch.bool),
... "labels": MemmapTensor(50000, 1, dtype=torch.uint8),
... }, batch_size=[50000], device="cpu")
>>> idx = [1, 5020, 34572, 11200]
>>> batch = dataset[idx].clone()
TensorDict(
fields={
images: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.float32),
labels: Tensor(torch.Size([4, 1]), dtype=torch.uint8),
masks: Tensor(torch.Size([4, 480, 480, 3]), dtype=torch.bool)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)
請注意,我們已指定了 MemmapTensor 的裝置。這種語法糖允許在需要時將查詢到的張量直接載入到裝置上。
需要考慮的另一個問題是,目前 MemmapTensor 與自動微分(autograd)操作不相容。
跨節點操作記憶體對映張量¶
我們提供了一個簡單的分散式指令碼示例,其中一個程序建立一個記憶體對映張量,並將其引用傳送給另一個負責更新它的工作程序。您可以在 benchmark 目錄 中找到此示例。
簡而言之,我們的目標是展示在節點可以訪問共享物理儲存時,如何處理大張量的讀寫操作。步驟包括:
在磁碟上建立空張量;
設定要執行的本地和遠端操作;
使用 RPC 在工作程序之間傳遞命令,以讀取和寫入共享資料。
該示例首先編寫一個函式,該函式使用填充了 1 的張量來更新特定索引處的 TensorDict 例項。
>>> def fill_tensordict(tensordict, idx):
... tensordict[idx] = TensorDict(
... {"memmap": torch.ones(5, 640, 640, 3, dtype=torch.uint8)}, [5]
... )
... return tensordict
>>> fill_tensordict_cp = CloudpickleWrapper(fill_tensordict)
CloudpickleWrapper 確保該函式是可序列化的。接下來,我們建立一個相當大的 tensordict,以此說明如果必須透過常規的 tensorpipe 傳遞,它將很難從一個工作程序傳遞到另一個工作程序。
>>> tensordict = TensorDict(
... {"memmap": MemmapTensor(1000, 640, 640, 3, dtype=torch.uint8, prefix="/tmp/")}, [1000]
... )
最後,仍然在主節點上,我們在 *遠端節點* 上呼叫該函式,然後檢查資料是否已寫入所需位置。
>>> idx = [4, 5, 6, 7, 998]
>>> t0 = time.time()
>>> out = rpc.rpc_sync(
... worker_info,
... fill_tensordict_cp,
... args=(tensordict, idx),
... )
>>> print("time elapsed:", time.time() - t0)
>>> print("check all ones", out["memmap"][idx, :1, :1, :1].clone())
儘管呼叫 rpc.rpc_sync 涉及傳遞整個 tensordict,更新該物件的特定索引並將其返回給原始工作程序,但該程式碼片段的執行速度非常快(如果記憶體位置的引用已預先傳遞,速度會更快,請參閱 torchrl 的分散式回放緩衝區文件 瞭解更多資訊)。
該指令碼包含額外的 RPC 配置步驟,超出了本文件的範圍。