概述¶
TensorDict 使組織資料和編寫可重用、通用的 PyTorch 程式碼變得容易。它最初是為 TorchRL 開發的,現在已作為獨立庫釋出。
TensorDict 主要是一個字典,同時也是一個類似張量的類:它支援多種主要與形狀和儲存相關的張量操作。它旨在能夠有效地序列化或從節點到節點、程序到程序傳輸。最後,它還附帶了自己的 nn 模組,該模組與 torch.func 相容,旨在簡化模型整合和引數操作。
在本頁面,我們將介紹 TensorDict 的動機,並提供一些它功能的示例。
動機¶
TensorDict 允許您編寫可在不同正規化之間重用的通用程式碼模組。例如,以下迴圈可以用於大多數 SL、SSL、UL 和 RL 任務。
>>> for i, tensordict in enumerate(dataset):
... # the model reads and writes tensordicts
... tensordict = model(tensordict)
... loss = loss_module(tensordict)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
透過其 nn 模組,該包提供了許多工具,可以輕鬆地在程式碼庫中使用 TensorDict。
在多程序或分散式環境中,TensorDict 允許您無縫地將資料分發給每個工作程序。
>>> # creates batches of 10 datapoints
>>> splits = torch.arange(tensordict.shape[0]).split(10)
>>> for worker in range(workers):
... idx = splits[worker]
... pipe[worker].send(tensordict[idx])
TensorDict 提供的一些操作也可以透過 tree_map 完成,但複雜度更高。
>>> td = TensorDict(
... {"a": torch.randn(3, 11), "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": td["a"], "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": regular_dicts["a"][i], "b": regular_dicts["b"][i]}
... for i in range(3)]
巢狀情況更具說服力。
>>> td = TensorDict(
... {"a": {"c": torch.randn(3, 11)}, "b": torch.randn(3, 3)}, batch_size=3
... )
>>> regular_dict = {"a": {"c": td["a", "c"]}, "b": td["b"]}
>>> td0, td1, td2 = td.unbind(0)
>>> # similar structure with pytree
>>> regular_dicts = tree_map(lambda x: x.unbind(0))
>>> regular_dict1, regular_dict2, regular_dict3 = [
... {"a": {"c": regular_dicts["a"]["c"][i]}, "b": regular_dicts["b"][i]}
... for i in range(3)
在對 pytree 進行原始操作時,在應用 unbind 操作後將輸出字典分解為三個結構相似的字典,會變得非常麻煩。使用 tensordict,我們為想要 unbind 或拆分巢狀結構的使用者的提供了簡單的 API,而不是計算巢狀的拆分/unbind 巢狀結構。
特性¶
一個 TensorDict 是一個類似字典的張量容器。要例項化一個 TensorDict,您可以指定鍵值對以及批次大小(可以透過 TensorDict() 建立一個空的 tensordict)。TensorDict 中任何值的領先維度必須與批次大小相容。
>>> import torch
>>> from tensordict import TensorDict
>>> tensordict = TensorDict(
... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)},
... batch_size=[2, 3],
... )
設定或檢索值的語法與常規字典非常相似。
>>> zeros = tensordict["zeros"]
>>> tensordict["twos"] = 2 * torch.ones(2, 3)
還可以沿其 batch_size 索引一個 tensordict,這使得只需幾個字元即可獲得資料的同等切片(請注意,使用 ellipsis 和 tree_map 索引第 n 個領先維度需要更多的編碼)。
>>> sub_tensordict = tensordict[..., :2]
還可以使用 inplace=True 的 set 方法或 set_() 方法進行原地更新。前者是後者的容錯版本:如果找不到匹配的鍵,它將寫入一個新鍵。
現在可以集體操作 TensorDict 的內容。例如,要將所有內容放置到特定裝置,只需執行以下操作:
>>> tensordict = tensordict.to("cuda:0")
然後,您可以斷言 tensordict 的裝置是 “cuda:0”。
>>> assert tensordict.device == torch.device("cuda:0")
要重塑批次維度,您可以執行以下操作:
>>> tensordict = tensordict.reshape(6)
該類還支援許多其他操作,包括 squeeze()、unsqueeze()、view()、permute()、unbind()、stack()、cat() 等等。
如果缺少某個操作,apply() 方法通常會提供所需解決方案。
規避形狀操作¶
在某些情況下,可能希望在不強制形狀操作期間批次大小一致性的情況下將張量儲存在 TensorDict 中。
這可以透過將張量包裝在 UnbatchedTensor 例項中來實現。
UnbatchedTensor 在 TensorDict 的形狀操作期間會忽略其形狀,從而允許靈活地儲存和操作任意形狀的張量。
>>> from tensordict import UnbatchedTensor
>>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3])
>>> reshaped_td = tensordict.reshape(6)
>>> reshaped_td["zeros"] is tensordict["zeros"]
True
非張量資料¶
Tensordict 是一個強大的處理張量資料的庫,但也支援非張量資料。本指南將向您展示如何使用 tensordict 處理非張量資料。
使用非張量資料建立 TensorDict¶
您可以使用 NonTensorData 類建立包含非張量資料的 TensorDict。
>>> from tensordict import TensorDict, NonTensorData
>>> import torch
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(()),
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
如您所見,NonTensorData 物件就像普通張量一樣儲存在 TensorDict 中。
MetaData 類可用於承載不可索引的資料,或不需要遵循 tensordict 批次大小的資料。
訪問非張量資料¶
您可以使用鍵或 get 方法訪問非張量資料。常規的 getattr 呼叫將返回 NonTensorData 物件的內容,而 get() 將返回 NonTensorData 物件本身。
>>> print(td["a"]) # prints: a string!
>>> print(td.get("a")) # prints: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None)
批處理的非張量資料¶
如果您有一個非張量資料的批次,您可以將其儲存在具有指定批次大小的 TensorDict 中。
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([3]), device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
在這種情況下,我們假設 tensordict 的所有元素都具有相同的非張量資料。
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
要為已排序 tensordict 中的每個元素分配不同的非張量資料物件,可以使用非張量資料堆疊。
堆疊的非張量資料¶
如果您有一個要儲存在 TensorDict 中的非張量資料列表,您可以使用 NonTensorStack 類。
>>> td = TensorDict(
... a=NonTensorStack("a string!", "another string!", "a third string!"),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td)
TensorDict(
fields={
a: NonTensorStack(
['a string!', 'another string!', 'a third string!'...,
batch_size=torch.Size([3]),
device=None),
b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
您可以訪問第一個元素,您將獲得第一個字串。
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=a string!, batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
相比之下,將 NonTensorData 與列表一起使用不會得到相同的結果,因為無法確定如何普遍處理恰好是列表的非張量資料。
>>> td = TensorDict(
... a=NonTensorData(["a string!", "another string!", "a third string!"]),
... b=torch.zeros(3),
... batch_size=[3]
... )
>>> print(td[0])
TensorDict(
fields={
a: NonTensorData(data=['a string!', 'another string!', 'a third string!'], batch_size=torch.Size([]), device=None),
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
堆疊帶有非張量資料的 TensorDict¶
為了堆疊非張量資料,stack() 將建立一個 NonTensorStack。相比之下,在使用 MetaData 例項時,如果內容匹配,堆疊操作將生成單個 MetaData 例項。
>>> td = TensorDict(
... a=NonTensorData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
fields={
a: NonTensorStack(
['a string!', 'a string!'],
batch_size=torch.Size([2]),
device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
>>> td = TensorDict(
... a=MetaData("a string!"),
... b = torch.zeros(()),
... )
>>> print(torch.stack([td, td]))
TensorDict(
fields={
a: MetaData(data=a string!, batch_size=torch.Size([2]), device=None),
b: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
命名維度¶
TensorDict 和相關類也支援維度名稱。名稱可以在構造時給出,也可以稍後進行細化。其語義與 torch.Tensor 的維度名稱功能類似。
>>> tensordict = TensorDict({}, batch_size=[3, 4], names=["a", None])
>>> tensordict.refine_names(..., "b")
>>> tensordict.names = ["z", "y"]
>>> tensordict.rename("m", "n")
>>> tensordict.rename(m="h")
巢狀 TensorDicts¶
TensorDict 中的值本身可以是 TensorDicts(下面示例中的巢狀字典將被轉換為巢狀 TensorDicts)。
>>> tensordict = TensorDict(
... {
... "inputs": {
... "image": torch.rand(100, 28, 28),
... "mask": torch.randint(2, (100, 28, 28), dtype=torch.uint8)
... },
... "outputs": {"logits": torch.randn(100, 10)},
... },
... batch_size=[100],
... )
可以透過字串元組訪問或設定巢狀鍵。
>>> image = tensordict["inputs", "image"]
>>> logits = tensordict.get(("outputs", "logits")) # alternative way to access
>>> tensordict["outputs", "probabilities"] = torch.sigmoid(logits)
延遲評估¶
TensorDict 上的一些操作會推遲執行,直到訪問專案。例如,堆疊、擠壓、擴充套件、置換批次維度和建立檢視不會立即在 TensorDict 的所有內容上執行。相反,它們會在訪問 TensorDict 中的值時惰性執行。如果 TensorDict 包含許多值,這可以節省大量不必要的計算。
>>> tensordicts = [TensorDict({
... "a": torch.rand(10),
... "b": torch.rand(10, 1000, 1000)}, [10])
... for _ in range(3)]
>>> stacked = torch.stack(tensordicts, 0) # no stacking happens here
>>> stacked_a = stacked["a"] # we stack the a values, b values are not stacked
它還有一個優點是我們可以操縱堆疊中的原始 tensordicts。
>>> stacked["a"] = torch.zeros_like(stacked["a"])
>>> assert (tensordicts[0]["a"] == 0).all()
需要注意的是,get 方法現在變成了一個昂貴的操作,如果重複多次,可能會導致一些開銷。可以透過在 stack 執行後簡單呼叫 tensordict.contiguous() 來避免此問題。為了進一步緩解此問題,TensorDict 提供了自己的元資料類(MetaTensor),該類跟蹤字典每個條目的型別、形狀、dtype 和裝置,而無需執行昂貴的操作。
延遲預分配¶
假設我們有一個函式 foo() -> TensorDict,並且我們執行以下操作:
>>> tensordict = TensorDict({}, batch_size=[N])
>>> for i in range(N):
... tensordict[i] = foo()
當 i == 0 時,空的 TensorDict 將自動填充具有批次大小 N 的空張量。在迴圈的後續迭代中,所有更新都將原地寫入。
TensorDictModule¶
為了便於將 TensorDict 整合到程式碼庫中,我們提供了 tensordict.nn 包,允許使用者將 TensorDict 例項傳遞給 Module 物件(或任何可呼叫物件)。
TensorDictModule 包裝了 Module,並接受一個 TensorDict 作為輸入。您可以指定底層模組應從何處獲取輸入,以及應將輸出寫入何處。這是我們能夠編寫可重用的、通用的高階程式碼(如動機部分中的訓練迴圈)的關鍵原因。
>>> from tensordict.nn import TensorDictModule
>>> class Net(nn.Module):
... def __init__(self):
... super().__init__()
... self.linear = nn.LazyLinear(1)
...
... def forward(self, x):
... logits = self.linear(x)
... return logits, torch.sigmoid(logits)
>>> module = TensorDictModule(
... Net(),
... in_keys=["input"],
... out_keys=[("outputs", "logits"), ("outputs", "probabilities")],
... )
>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> tensordict = module(tensordict)
>>> # outputs can now be retrieved from the tensordict
>>> logits = tensordict["outputs", "logits"]
>>> probabilities = tensordict.get(("outputs", "probabilities"))
為了方便採用此類,您還可以將張量作為 kwargs 傳遞。
>>> tensordict = module(input=torch.randn(32, 100))
這將返回一個與上一個程式碼框中的 TensorDict 相同的 TensorDict。有關此功能的更多背景資訊,請參閱 匯出教程。
許多 PyTorch 使用者面臨的一個主要痛點是 nn.Sequential 無法處理具有多個輸入的模組。使用基於鍵的圖可以輕鬆解決此問題,因為序列中的每個節點都知道需要讀取哪些資料以及寫入何處。
為此,我們提供了 TensorDictSequential 類,它將資料傳遞給一系列 TensorDictModules。序列中的每個模組都從原始 TensorDict 獲取輸入,並將輸出寫入其中,這意味著序列中的模組可以忽略其前驅的輸出,或根據需要從 tensordict 獲取額外的輸入。這是一個例子:
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
... class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>> tensordict = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> tensordict = module(tensordict)
>>> intermediate_x = tensordict["intermediate", "x"]
>>> probabilities = tensordict["output", "probabilities"]
在此示例中,第二個模組將第一個模組的輸出與儲存在 TensorDict 的 (“inputs”, “mask”) 下的掩碼結合起來。
TensorDictSequential 提供了許多其他功能:可以透過查詢 in_keys 和 out_keys 屬性來訪問輸入和輸出鍵的列表。還可以透過使用所需的輸入和輸出鍵集查詢 select_subsequence() 來請求子圖。這將返回另一個 TensorDictSequential,其中僅包含滿足這些要求所必需的模組。 TensorDictModule 也相容 vmap() 和其他 torch.func 功能。