注意
轉到末尾 下載完整的示例程式碼。
模型整合#
創建於: 2023 年 3 月 15 日 | 最後更新: 2025 年 10 月 2 日 | 最後驗證: 2024 年 11 月 5 日
本教程演示瞭如何使用 torch.vmap 實現模型整合的向量化。
什麼是模型整合?#
模型整合是將多個模型的預測結果組合在一起。傳統上,這是透過分別對每個模型執行一些輸入,然後組合預測結果來實現的。但是,如果您執行的是具有相同架構的模型,則可以使用 torch.vmap 將它們組合在一起。vmap 是一個函式變換,可以將函式對映到輸入張量的維度上。它的一個用例是消除 for 迴圈並透過向量化來加速它們。
讓我們透過一個簡單 MLP 的整合來演示這一點。
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(0)
# Here's a simple MLP
class SimpleMLP(nn.Module):
def __init__(self):
super(SimpleMLP, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = x.flatten(1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
return x
讓我們生成一批虛擬資料,並假設我們正在處理 MNIST 資料集。因此,虛擬影像的大小為 28x28,我們有一個大小為 64 的小批次。此外,假設我們想合併 10 個不同模型的預測結果。
device = torch.accelerator.current_accelerator()
num_models = 10
data = torch.randn(100, 64, 1, 28, 28, device=device)
targets = torch.randint(10, (6400,), device=device)
models = [SimpleMLP().to(device) for _ in range(num_models)]
我們有幾種生成預測的選項。也許我們想給每個模型一個不同的隨機小批次資料。或者,也許我們想將相同的小批次資料透過每個模型(例如,如果我們正在測試不同模型初始化的效果)。
選項 1:每個模型使用不同的小批次
minibatches = data[:num_models]
predictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]
選項 2:相同的小批次
使用 vmap 對整合進行向量化#
讓我們使用 vmap 來加速 for 迴圈。我們必須首先準備好模型以供 vmap 使用。
首先,讓我們透過堆疊每個引數來組合模型的狀態。例如,model[i].fc1.weight 的形狀是 [784, 128];我們將堆疊 10 個模型的 .fc1.weight 以生成一個形狀為 [10, 784, 128] 的大權重。
PyTorch 提供了 torch.func.stack_module_state 便利函式來完成此操作。
from torch.func import stack_module_state
params, buffers = stack_module_state(models)
接下來,我們需要定義一個函式來 vmap。該函式應接收引數、緩衝區和輸入,並使用這些引數、緩衝區和輸入來執行模型。我們將使用 torch.func.functional_call 來提供幫助。
from torch.func import functional_call
import copy
# Construct a "stateless" version of one of the models. It is "stateless" in
# the sense that the parameters are meta Tensors and do not have storage.
base_model = copy.deepcopy(models[0])
base_model = base_model.to('meta')
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))
選項 1:使用不同的小批次為每個模型獲取預測。
預設情況下,vmap 會將函式對映到傳入函式的所有輸入的第一個維度上。在使用 stack_module_state 之後,每個 params 和緩衝區都將在前面有一個大小為“num_models”的額外維度,並且小批次有一個大小為“num_models”的維度。
print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension
assert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatch has leading dimension of size 'num_models'
from torch import vmap
predictions1_vmap = vmap(fmodel)(params, buffers, minibatches)
# verify the ``vmap`` predictions match the
assert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)
[10, 10, 10, 10, 10, 10]
選項 2:使用相同的小批次資料獲取預測。
vmap 有一個 in_dims 引數,用於指定要對映的維度。透過使用 None,我們告訴 vmap 我們希望將相同的小批次應用於所有 10 個模型。
predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)
快速提示:vmap 可轉換的函式型別存在一些限制。最適合轉換的函式是純函式:一個輸出僅由無副作用(例如突變)的輸入決定的函式。vmap 無法處理任意 Python 資料結構的突變,但可以處理許多原地 PyTorch 操作。
效能#
對效能資料感到好奇嗎?以下是數字顯示。
from torch.utils.benchmark import Timer
without_vmap = Timer(
stmt="[model(minibatch) for model, minibatch in zip(models, minibatches)]",
globals=globals())
with_vmap = Timer(
stmt="vmap(fmodel)(params, buffers, minibatches)",
globals=globals())
print(f'Predictions without vmap {without_vmap.timeit(100)}')
print(f'Predictions with vmap {with_vmap.timeit(100)}')
Predictions without vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a83d9c3a0>
[model(minibatch) for model, minibatch in zip(models, minibatches)]
1.52 ms
1 measurement, 100 runs , 1 thread
Predictions with vmap <torch.utils.benchmark.utils.common.Measurement object at 0x7f4a83d9f850>
vmap(fmodel)(params, buffers, minibatches)
523.24 us
1 measurement, 100 runs , 1 thread
使用 vmap 可以大幅提高速度!
總的來說,使用 vmap 進行向量化應該比在 for 迴圈中執行函式要快,並且與手動批處理相當。但也有一些例外,例如如果我們尚未為特定操作實現 vmap 規則,或者底層核心未針對舊版硬體(GPU)進行最佳化。如果您遇到任何這些情況,請透過在 GitHub 上提交 issue 來告知我們。
指令碼總執行時間: (0 分鐘 0.782 秒)