注意
轉到末尾 下載完整的示例程式碼。
匯出 tensordict 模組¶
先決條件¶
最好先閱讀 TensorDictModule 教程,以充分受益於本教程。
一旦使用 tensordict.nn 編寫了一個模組,通常會希望隔離計算圖並匯出該圖。這樣做的目標可能是為了在硬體(例如,機器人、無人機、邊緣裝置)上執行模型,或者完全消除對 tensordict 的依賴。
PyTorch 提供了多種匯出模組的方法,包括 onnx 和 torch.export,它們都與 tensordict 相容。
在本簡短教程中,我們將瞭解如何使用 torch.export 來隔離模型的計算圖。torch.onnx 支援遵循相同的邏輯。
關鍵學習點¶
在沒有
TensorDict輸入的情況下執行tensordict.nn模組;選擇模型的輸出;
處理隨機模型;
使用 torch.export 匯出此類模型;
將模型儲存到檔案;
隔離 PyTorch 模型;
import time
import torch
from tensordict.nn import (
InteractionType,
NormalParamExtractor,
ProbabilisticTensorDictModule as Prob,
set_interaction_type,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import distributions as dists, nn
設計模型¶
在許多應用中,處理隨機模型很有用,即模型輸出一個變數,該變數不確定地定義,而是根據引數化分佈進行取樣。例如,生成式 AI 模型在提供相同輸入時通常會生成不同的輸出,因為它們根據引數由輸入定義的分佈來取樣輸出。
tensordict 庫透過 ProbabilisticTensorDictModule 類來處理這種情況。這個原語是使用一個分佈類(在我們的例子中是 Normal)和指示在執行時用於構建該分佈的輸入鍵的指示器構建的。
因此,我們正在構建的網路將是三個主要元件的組合:
一個將輸入對映到潛在引數的網路;
一個
tensordict.nn.NormalParamExtractor模組,它將輸入拆分為要傳遞給Normal分佈的“loc”和“scale”引數;一個分佈建構函式模組。
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
# 3. Probabilistic module
Prob(
in_keys=["loc", "scale"],
out_keys=["sample"],
distribution_class=dists.Normal,
),
)
讓我們執行這個模型,看看輸出是什麼樣的
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.0978, 0.0400, 0.0000, 0.2323]], grad_fn=<ReluBackward0>), tensor([[-0.2766, 0.2466, -0.2790, 0.4099]], grad_fn=<AddmmBackward0>), tensor([[-0.2766, 0.2466]], grad_fn=<SplitBackward0>), tensor([[0.8339, 1.2764]], grad_fn=<ClampMinBackward0>), tensor([[-0.2766, 0.2466]], grad_fn=<SplitBackward0>))
正如預期的那樣,使用張量輸入執行模型會返回與模組的輸出鍵一樣多的張量!對於大型模型來說,這可能相當煩人和浪費。稍後,我們將看到如何限制模型的輸出數量來解決這個問題。
使用 torch.export 和 TensorDictModule¶
現在我們已經成功構建了模型,我們希望將其計算圖提取到一個獨立於 tensordict 的單個物件中。torch.export 是一個專門用於隔離模組圖並以標準化方式表示它的 PyTorch 模組。它的主要入口點是 export(),它返回一個 ExportedProgram 物件。反過來,該物件有幾個我們將在下面探討的有趣屬性:一個 graph_module,它表示 export 捕獲的 FX 圖;一個 graph_signature,包含圖的輸入、輸出等;最後,一個 module(),它返回一個可呼叫物件,可以替換原始模組。
雖然我們的模組接受 args 和 kwargs,但我們將重點關注它與 kwargs 的用法,因為這樣更清晰。
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x}, strict=True)
我們來看看這個模組
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
return pytree.tree_unflatten((relu, linear_1, getitem_2, getitem_3, getitem_2), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
此模組的執行方式與我們的原始模組完全相同(開銷較低)
Time for TDModule: 548.60 micro-seconds
Time for exported module: 689.98 micro-seconds
以及 FX 圖
print("fx graph:", model_export.graph_module.print_readable())
class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/common.py:1133 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:58 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_l__args___0_module_0_module_weight: "f32[4, 3]", p_l__args___0_module_0_module_bias: "f32[4]", p_l__args___0_module_2_module_weight: "f32[4, 4]", p_l__args___0_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/common.py:1133 in _call_module, code: out = self.module(*tensors, **kwargs)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_l__args___0_module_0_module_weight, p_l__args___0_module_0_module_bias); x = p_l__args___0_module_0_module_weight = p_l__args___0_module_0_module_bias = None
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_l__args___0_module_2_module_weight, p_l__args___0_module_2_module_bias); p_l__args___0_module_2_module_weight = p_l__args___0_module_2_module_bias = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/distributions/utils.py:58 in broadcast_all, code: return torch.broadcast_tensors(*values)
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2: "f32[1, 2]" = broadcast_tensors[0]
getitem_3: "f32[1, 2]" = broadcast_tensors[1]; broadcast_tensors = None
return (relu, linear_1, getitem_2, getitem_3, getitem_2)
使用巢狀鍵¶
巢狀鍵是 tensordict 庫的核心功能,因此能夠匯出讀取和寫入巢狀條目的模組是一項重要的支援功能。由於關鍵字引數必須是常規字串,因此 dispatch 無法直接使用它們。相反,dispatch 將解包用常規下劃線(“_”)連線的巢狀鍵,如下例所示。
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule()
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
請注意,由 module() 返回的可呼叫物件是一個純 Python 可呼叫物件,可以進而使用 compile() 進行編譯。
儲存匯出的模組¶
torch.export 有自己的序列化協議,save() 和 load()。約定俗成,應使用 “.pt2” 副檔名
>>> torch.export.save(model_export, "model.pt2")
選擇輸出¶
回想一下,tensordict.nn 的作用是保留輸出中的所有中間值,除非使用者明確要求只使用特定值。在訓練期間,這可能非常有用:可以輕鬆記錄圖的中間值,或將它們用於其他目的(例如,根據其儲存的引數重構分佈,而不是儲存 Distribution 物件本身)。也可以認為,在訓練期間,由於它們是 torch.autograd 用於計算引數梯度的計算圖的一部分,因此註冊中間值對記憶體的影響可以忽略不計。
然而,在推理期間,我們最有可能只對模型的最終樣本感興趣。因為我們希望提取模型用於與 tensordict 庫無關的用途,所以只隔離我們想要的輸出是有意義的。為此,我們有幾種選擇:
使用
selected_out_keys關鍵字引數構建TensorDictSequential(),這將指示在呼叫模組時選擇所需的條目;使用
select_out_keys()方法,該方法將就地修改out_keys屬性(可以透過reset_out_keys()恢復)。將現有例項包裝在
TensorDictSequential()中,該函式將過濾掉不需要的鍵>>> module_filtered = Seq(module, selected_out_keys=["sample"])
讓我們在選擇輸出鍵後測試模型。當提供 x 輸入時,我們期望我們的模型輸出一個對應於分佈樣本的單個張量
tensor([[-0.2766, 0.2466]], grad_fn=<SplitBackward0>)
我們看到輸出現在是單個張量,對應於分佈的樣本。我們可以從中建立一個新的匯出圖。其計算圖應該被簡化
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
控制取樣策略¶
我們還沒有討論 ProbabilisticTensorDictModule 如何從分佈中取樣。透過取樣,我們指的是根據特定策略在分佈定義的空間內獲取一個值。例如,在訓練期間可能希望獲得隨機樣本,但在推理時獲得確定性樣本(例如,均值或眾數)。為了解決這個問題,tensordict 利用 set_interaction_type 裝飾器和上下文管理器,它接受 InteractionType 列舉輸入
>>> with set_interaction_type(InteractionType.MEAN):
... output = module(input) # takes the input of the distribution, if ProbabilisticTensorDictModule is invoked
預設的 InteractionType 是 InteractionType.DETERMINISTIC,如果不直接實現,則為具有實數域的分佈的均值,或具有離散域的分佈的眾數。可以使用 ProbabilisticTensorDictModule 的 default_interaction_type 關鍵字引數更改此預設值。
總而言之:要控制網路的取樣策略,我們可以在建構函式中定義預設取樣策略,或者透過 set_interaction_type 上下文管理器在執行時覆蓋它。
從以下示例中可以看到,torch.export 正確響應了裝飾器的使用:如果我們要求隨機樣本,輸出與要求均值時的輸出不同
with set_interaction_type(InteractionType.RANDOM):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
with set_interaction_type(InteractionType.MEAN):
model_export = export(model, args=(), kwargs={"x": x})
print(model_export.module())
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = None
empty = torch.ops.aten.empty.memory_format([1, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
normal_ = torch.ops.aten.normal_.default(empty); empty = None
mul = torch.ops.aten.mul.Tensor(normal_, getitem_3); normal_ = getitem_3 = None
add_2 = torch.ops.aten.add.Tensor(getitem_2, mul); getitem_2 = mul = None
return pytree.tree_unflatten((add_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
broadcast_tensors = torch.ops.aten.broadcast_tensors.default([getitem, clamp_min]); getitem = clamp_min = None
getitem_2 = broadcast_tensors[0]
getitem_3 = broadcast_tensors[1]; broadcast_tensors = getitem_3 = None
return pytree.tree_unflatten((getitem_2,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
這就是使用 torch.export 所需瞭解的全部內容。有關更多資訊,請參閱 官方文件。
後續步驟和進一步閱讀¶
檢視
torch.export教程,可在 此處 找到;ONNX 支援:檢視 ONNX 教程,瞭解有關此功能的更多資訊。匯出到 ONNX 與此處解釋的 torch.export 非常相似。
要在沒有 Python 環境的伺服器上部署 PyTorch 程式碼,請檢視 AOTInductor 文件。
指令碼總執行時間:(0 分 4.507 秒)