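TensorDictModule
Authors: Nicolas Dufour, Vincent Moens
In this tutorial, you will learn how to use TensorDictModule and TensorDictSequential to create generic, reusable modules that accept a TensorDict as input.
To make it convenient to use the TensorDict class with Module, tensordict provides an interface between the two named TensorDictModule.
TensorDictModule is a Module that takes a TensorDict as input when called. It reads a sequence of input keys, passes the corresponding entries to the wrapped module or function, and writes the outputs to the same tensordict once execution completes.
The user defines which input and output keys are read and written.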
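As a rough mental model of this read/call/write cycle, the sketch below mimics it with a plain Python dictionary. The class `KeyedCall` is hypothetical and written only for illustration; it is not how tensordict implements TensorDictModule.

```python
from typing import Callable, Dict, Sequence


class KeyedCall:
    """Hypothetical stand-in for the read/call/write pattern (not the real class)."""

    def __init__(self, fn: Callable, in_keys: Sequence[str], out_keys: Sequence[str]):
        self.fn = fn
        self.in_keys = list(in_keys)
        self.out_keys = list(out_keys)

    def __call__(self, data: Dict[str, float]) -> Dict[str, float]:
        # 1. Read the inputs named by in_keys.
        args = [data[key] for key in self.in_keys]
        # 2. Call the wrapped function on them.
        outputs = self.fn(*args)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # 3. Write each output under its out_key, in the same container.
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


data = {"x": 2.0}
double = KeyedCall(lambda x: 2 * x, in_keys=["x"], out_keys=["y"])
result = double(data)
```

The container that comes back is the one that went in, now holding an extra entry "y"; the rest of this tutorial relies on that same convention.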
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
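Simple example: coding a residual layer
The simplest usage of TensorDictModule is shown below. Using this class may at first seem to introduce unnecessary complexity, but we will see later that this API lets users programmatically chain modules together, cache values between modules, or build modules programmatically. One of the simplest examples is a residual module in a ResNet-like architecture, where the module's input is cached and added to the output of a small multi-layer perceptron (MLP).
First, let's consider how to split an MLP into chunks and code it with tensordict.nn. The first layer of the stack would presumably be a Linear layer that takes an entry as input (let's call it x) and outputs another entry (let's call it y).
To feed our module, we have a TensorDict instance with a single entry, "x":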
tensordict = TensorDict(
    x=torch.randn(5, 3),
    batch_size=[5],
)
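Now we build our simple module with tensordict.nn.TensorDictModule. By default, this class writes to the input tensordict in-place (meaning that entries are written to the same tensordict as the input, not that existing entries are overwritten in-place!), so we don't need to explicitly indicate what the output is.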
linear0 = TensorDictModule(nn.Linear(3, 128), in_keys=["x"], out_keys=["linear0"])
linear0(tensordict)
assert "linear0" in tensordict
If the module outputs multiple tensors (or tensordicts!), their entries must be passed to TensorDictModule in the right order.
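Support for callables
When designing a model, one often needs to incorporate an arbitrary non-parametric function into the network. For instance, you may want to permute the dimensions of an image before it is passed to a convolutional network or a vision transformer, or divide its values by 255. There are several ways to do this: you could use a forward_hook, or design a new Module that performs the operation.
TensorDictModule works with any callable, not just modules, which makes it easy to integrate arbitrary functions into a model. For instance, let's see how to integrate the relu activation function without using the ReLU module.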
relu0 = TensorDictModule(torch.relu, in_keys=["linear0"], out_keys=["relu0"])
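Stacking modules
Our MLP isn't made of a single layer, so we now need to add another layer to it. This layer is an activation function, such as ReLU, which we wrapped above as relu0. We can stack this module and the previous one using TensorDictSequential.
Note
This is where the true power of tensordict.nn lies: unlike Sequential, TensorDictSequential keeps all previous inputs and outputs in memory (with the option of filtering them out afterwards), which makes it easy to build complex network structures programmatically and on the fly.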
block0 = TensorDictSequential(linear0, relu0)
block0(tensordict)
assert "linear0" in tensordict
assert "relu0" in tensordict
We can repeat this logic to build a full MLP:
linear1 = TensorDictModule(nn.Linear(128, 128), in_keys=["relu0"], out_keys=["linear1"])
relu1 = TensorDictModule(nn.ReLU(), in_keys=["linear1"], out_keys=["relu1"])
linear2 = TensorDictModule(nn.Linear(128, 3), in_keys=["relu1"], out_keys=["linear2"])
block1 = TensorDictSequential(linear1, relu1, linear2)
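Multiple input keys
The last step of the residual network is to add the input to the output of the last linear layer. There is no need to write a special Module subclass for this! TensorDictModule can wrap simple functions too.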
residual = TensorDictModule(
    lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"]
)
Now we can put together block0, block1 and residual to build a complete residual block:
block = TensorDictSequential(block0, block1, residual)
block(tensordict)
assert "y" in tensordict
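A genuine concern may be the accumulation of entries in the tensordict. In some cases (for instance, when gradients are required) intermediate values are cached anyway, but this isn't always the case, and it can be useful to let the garbage collector know that some entries can be discarded. tensordict.nn.TensorDictModuleBase and its subclasses (including tensordict.nn.TensorDictModule and tensordict.nn.TensorDictSequential) can filter their output keys after execution. To do this, just call the tensordict.nn.TensorDictModuleBase.select_out_keys method. This updates the module in-place, and all unwanted entries will be discarded: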
block.select_out_keys("y")
tensordict = TensorDict(x=torch.randn(1, 3), batch_size=[1])
block(tensordict)
assert "y" in tensordict
assert "linear1" not in tensordict
However, the input keys are preserved:
assert "x" in tensordict
As a side note, selected_out_keys can also be passed to tensordict.nn.TensorDictSequential at construction time, which avoids calling this method separately.
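Using TensorDictModule without tensordict
The opportunity offered by tensordict.nn.TensorDictSequential to build complex architectures on the fly does not mean that one must switch to tensordict to represent the data. Thanks to dispatch, modules from tensordict.nn also accept positional and keyword arguments that match the entry names: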
x = torch.randn(1, 3)
y = block(x=x)
assert isinstance(y, torch.Tensor)
Under the hood, dispatch rebuilds a tensordict, runs the module and then deconstructs it. This can add some overhead but, as we will see next, there is a way to remove it.
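Runtime
tensordict.nn.TensorDictModule and tensordict.nn.TensorDictSequential incur some overhead at execution time, because they need to read from and write to a tensordict. However, this overhead can be greatly reduced with compile(). To see this, let's compare three versions of this code, with and without compile.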
class ResidualBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = nn.Linear(3, 128)
        self.relu0 = nn.ReLU()
        self.linear1 = nn.Linear(128, 128)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(128, 3)

    def forward(self, x):
        y = self.linear0(x)
        y = self.relu0(y)
        y = self.linear1(y)
        y = self.relu1(y)
        return self.linear2(y) + x


print("Without compile")
x = torch.randn(256, 3)
block_notd = ResidualBlock()
block_tdm = TensorDictModule(block_notd, in_keys=["x"], out_keys=["y"])
block_tds = block
from torch.utils.benchmark import Timer

print(
    f"Regular: {Timer('block_notd(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"TDM: {Timer('block_tdm(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
print(
    f"Sequential: {Timer('block_tds(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)

print("Compiled versions")
block_notd_c = torch.compile(block_notd, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_notd_c(x)
print(
    f"Compiled regular: {Timer('block_notd_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tdm_c = torch.compile(block_tdm, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tdm_c(x=x)
print(
    f"Compiled TDM: {Timer('block_tdm_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
block_tds_c = torch.compile(block_tds, mode="reduce-overhead")
for _ in range(5):  # warmup
    block_tds_c(x=x)
print(
    f"Compiled sequential: {Timer('block_tds_c(x=x)', globals=globals()).adaptive_autorange().median * 1_000_000: 4.4f} us"
)
Without compile
Regular: 215.3987 us
TDM: 280.3646 us
Sequential: 503.1584 us
Compiled versions
cudagraph partition due to non gpu ops. Found from :
File "/pytorch/tensordict/docs/source/reference/generated/tutorials/tensordict_module.py", line 200, in forward
y = self.linear0(x)
Compiled regular: 374.3750 us
Compiled TDM: 408.7855 us
Compiled sequential: 377.9500 us
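As the timings show, once compiled, TensorDictSequential runs about as fast as the compiled plain module (377.95 us vs. 374.38 us), so the overhead it introduces relative to that baseline is essentially eliminated. In this particular run the compiled versions were nonetheless slower than the uncompiled plain module (215.40 us), so absolute timings will depend on the environment.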
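Do's and don'ts with TensorDictModule
Don't wrap modules from tensordict.nn in torch.nn.Sequential. It would break the input/output key structure. Always rely on tensordict.nn.TensorDictSequential instead.
Don't assign the output tensordict to a new variable, since the output tensordict is just the input modified in-place. Assigning a new variable name isn't strictly prohibited, but you might then expect both to disappear when one is deleted, whereas the garbage collector will still see the tensors in the workspace and no memory will be freed.
>>> tensordict = module(tensordict)  # ok!
>>> tensordict_out = module(tensordict)  # don't!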
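Working with distributions: ProbabilisticTensorDictModule
ProbabilisticTensorDictModule is a non-parametric module that represents a probability distribution. The distribution parameters are read from the input tensordict, and the output is written to an output tensordict. The output is sampled from these parameters according to a rule specified by the default_interaction_type argument and the interaction_type() global function. If the two conflict, the context manager takes precedence.
It can be wired together with a TensorDictModule that returns a tensordict updated with the distribution parameters, using ProbabilisticTensorDictSequential. This is a special case of TensorDictSequential whose last layer is a ProbabilisticTensorDictModule instance.
ProbabilisticTensorDictModule is responsible for constructing the distribution (through the get_dist() method) and/or sampling from this distribution (through a regular forward call to the module). The same get_dist() method is exposed by ProbabilisticTensorDictSequential.
The parameters can be found in the output tensordict, along with the log-probability if requested.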
from tensordict.nn import (
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import NormalParamExtractor
from torch import distributions as dist
td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3])
net = torch.nn.GRUCell(4, 8)
net = TensorDictModule(net, in_keys=["input", "hidden"], out_keys=["hidden"])
extractor = NormalParamExtractor()
extractor = TensorDictModule(extractor, in_keys=["hidden"], out_keys=["loc", "scale"])
td_module = ProbabilisticTensorDictSequential(
    net,
    extractor,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=dist.Normal,
        return_log_prob=True,
    ),
)
print(f"TensorDict before going through module: {td}")
td_module(td)
print(f"TensorDict after going through module now as keys action, loc and scale: {td}")
TensorDict before going through module: TensorDict(
fields={
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
TensorDict after going through module now as keys action, loc and scale: TensorDict(
fields={
action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
action_log_prob: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False),
input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3]),
device=None,
is_shared=False)
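Conclusion
We have seen how tensordict.nn can be used to dynamically build complex neural architectures on the fly. This opens the possibility of building pipelines that are oblivious to the model signature, i.e., writing generic code that uses networks with an arbitrary number of inputs or outputs in a flexible manner.
We have also seen how dispatch makes it possible to build such networks with tensordict.nn and use them without resorting to TensorDict directly. Thanks to compile(), the overhead introduced by tensordict.nn.TensorDictSequential can be essentially removed, leaving users with a neat, tensordict-free version of their module.
In the next tutorial, we will see how torch.export can be used to isolate a module and export it.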
Total running time of the script: (0 minutes 13.884 seconds)