tensordict.nn 包¶
tensordict.nn 包使得在 ML 管道中靈活使用 TensorDict 成為可能。
由於 TensorDict 將程式碼的一部分轉換為基於鍵的結構,現在可以使用這些鍵作為鉤子來構建複雜的圖結構。基本構建塊是 TensorDictModule,它使用一組輸入和輸出鍵包裝一個 torch.nn.Module 例項。
>>> from torch.nn import Transformer
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> import torch
>>> module = TensorDictModule(Transformer(), in_keys=["feature", "target"], out_keys=["prediction"])
>>> data = TensorDict({"feature": torch.randn(10, 11, 512), "target": torch.randn(10, 11, 512)}, [10, 11])
>>> data = module(data)
>>> print(data)
TensorDict(
fields={
feature: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
prediction: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32),
target: Tensor(torch.Size([10, 11, 512]), dtype=torch.float32)},
batch_size=torch.Size([10, 11]),
device=None,
is_shared=False)
不一定需要使用 TensorDictModule,一個具有有序輸入和輸出鍵(分別稱為 module.in_keys 和 module.out_keys)的自定義 torch.nn.Module 就足夠了。
許多 PyTorch 使用者面臨的一個痛點是 nn.Sequential 無法處理具有多個輸入的模組。使用基於鍵的圖可以輕鬆解決這個問題,因為序列中的每個節點都知道需要讀取哪些資料以及將資料寫入何處。
為此,我們提供了 TensorDictSequential 類,該類將資料透過 TensorDictModules 的序列傳遞。序列中的每個模組都從原始 TensorDict 中獲取輸入,並將其輸出寫入 TensorDict,這意味著序列中的模組可以忽略其前驅的輸出,或者根據需要從 tensordict 獲取其他輸入。示例如下:
>>> from tensordict.nn import TensorDictSequential
>>> class Net(nn.Module):
... def __init__(self, input_size=100, hidden_size=50, output_size=10):
... super().__init__()
... self.fc1 = nn.Linear(input_size, hidden_size)
... self.fc2 = nn.Linear(hidden_size, output_size)
...
... def forward(self, x):
... x = torch.relu(self.fc1(x))
... return self.fc2(x)
...
>>> class Masker(nn.Module):
... def forward(self, x, mask):
... return torch.softmax(x * mask, dim=1)
...
>>> net = TensorDictModule(
... Net(), in_keys=[("input", "x")], out_keys=[("intermediate", "x")]
... )
>>> masker = TensorDictModule(
... Masker(),
... in_keys=[("intermediate", "x"), ("input", "mask")],
... out_keys=[("output", "probabilities")],
... )
>>> module = TensorDictSequential(net, masker)
>>>
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> td = module(td)
>>> print(td)
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
output: TensorDict(
fields={
probabilities: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
我們還可以透過 select_subsequence() 方法輕鬆選擇子圖。
>>> sub_module = module.select_subsequence(out_keys=[("intermediate", "x")])
>>> td = TensorDict(
... {
... "input": TensorDict(
... {"x": torch.rand(32, 100), "mask": torch.randint(2, size=(32, 10))},
... batch_size=[32],
... )
... },
... batch_size=[32],
... )
>>> sub_module(td)
>>> print(td) # the "output" has not been computed
TensorDict(
fields={
input: TensorDict(
fields={
mask: Tensor(torch.Size([32, 10]), dtype=torch.int64),
x: Tensor(torch.Size([32, 100]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False),
intermediate: TensorDict(
fields={
x: Tensor(torch.Size([32, 10]), dtype=torch.float32)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)},
batch_size=torch.Size([32]),
device=None,
is_shared=False)
最後,tensordict.nn 提供了一個 ProbabilisticTensorDictModule,允許根據網路輸出來構建分佈,並從中獲取摘要統計資訊或樣本(以及分佈引數)。
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.distributions import NormalParamExtractor
>>> from tensordict.nn.prototype import (
... ProbabilisticTensorDictModule,
... ProbabilisticTensorDictSequential,
... )
>>> from torch.distributions import Normal
>>> td = TensorDict(
... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3]
... )
>>> net = torch.nn.Sequential(torch.nn.GRUCell(4, 8), NormalParamExtractor())
>>> module = TensorDictModule(
... net, in_keys=["input", "hidden"], out_keys=["loc", "scale"]
... )
>>> prob_module = ProbabilisticTensorDictModule(
... in_keys=["loc", "scale"],
... out_keys=["sample"],
... distribution_class=Normal,
... return_log_prob=True,
... )
>>> td_module = ProbabilisticTensorDictSequential(module, prob_module)
>>> td_module(td)
>>> print(td)
TensorDict(
fields={
action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
hidden: Tensor(torch.Size([3, 8]), dtype=torch.float32),
input: Tensor(torch.Size([3, 4]), dtype=torch.float32),
loc: Tensor(torch.Size([3, 4]), dtype=torch.float32),
sample_log_prob: Tensor(torch.Size([3, 4]), dtype=torch.float32),
scale: Tensor(torch.Size([3, 4]), dtype=torch.float32)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
|
TensorDict 模組的基類。 |
|
TensorDictModule 是一個 python 包裝器,用於包裝一個讀取和寫入 TensorDict 的 |
|
機率性 TD 模組。 |
|
一個包含至少一個 |
|
TensorDictModules 的序列。 |
|
TensorDictModule 物件的包裝類。 |
|
PyTorch 可呼叫物件的 cudagraph 包裝器。 |
|
處理 TensorDict 例項的任何可呼叫物件的包裝器。 |
|
與分佈互動的可能型別列表。 |
|
設定所有 ProbabilisticTDModules 取樣到所需的型別。 |
|
控制 |
|
返回 |
|
將函式轉換為 TensorDictModule 的裝飾器。 |
整合¶
函式式方法使得實現簡單的整合成為可能。我們可以使用 tensordict.nn.EnsembleModule 來複制和重新初始化模型副本。
>>> import torch
>>> from torch import nn
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules import EnsembleModule
>>> from tensordict import TensorDict
>>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
>>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
>>> ensemble = EnsembleModule(mod, num_copies=3)
>>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
>>> ensemble(data)
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 10]),
device=None,
is_shared=False)
|
包裝一個模組並重復它以形成整合的模組。 |
編譯 TensorDictModules¶
自 v0.5 起,TensorDict 元件與 compile() 相容。例如,TensorDictSequential 模組可以使用 torch.compile 進行編譯,並達到與包裝在 TensorDictModule 中的常規 PyTorch 模組相似的執行時效能。
分佈¶
一個新增可訓練的、狀態獨立的尺度引數的 nn.Module。 |
|
|
一個使用 TensorDict 介面將多個分佈組合在一起的複合分佈。 |
|
Delta 分佈。 |
|
一個非引數的 nn.Module,將輸入分割為 loc 和 scale 引數。 |
|
獨熱(One-hot)分類分佈。 |
|
截斷正態分佈。 |
Utils¶
|
從關鍵字引數或輸入字典返回一個建立的 TensorDict。 |
|
允許使用 kwargs 呼叫期望 TensorDict 的函式。 |
|
反向 softplus 函式。 |
|
帶偏置的 softplus 模組。 |
|
一個用於在 TensorDict 圖中跳過現有節點的上下文管理器。 |
返回一個模組是否應該重新計算 tensordict 中的現有條目。 |
|
|
新增一個自定義對映,用於對映類。 |
|
給定一個輸入字串,返回一個滿射函式 f(x): R -> R^+。 |
|