評價此頁

torch.func.stack_module_state#

torch.func.stack_module_state(models) params, buffers[原始碼]#

為使用 vmap() 進行整合準備一個 `torch.nn.Module` 列表。

給定一個由 M 個同類 nn.Module 組成的列表,返回兩個字典,它們將所有引數和緩衝區按名稱堆疊在一起。堆疊的引數是可最佳化的(即它們是 autograd 歷史中的新葉子節點,與原始引數無關,可以直接傳遞給最佳化器)。

以下是一個整合一個非常簡單的模型的示例

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)


def wrapper(params, buffers, data):
    return torch.func.functional_call(models[0], (params, buffers), data)


params, buffers = stack_module_state(models)
output = vmap(wrapper, (0, 0, None))(params, buffers, data)

assert output.shape == (num_models, batch_size, out_features)

當存在子模組時,會遵循 state dict 的命名約定

import torch.nn as nn


class Foo(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        hidden = 4
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.l2(self.l1(x))


num_models = 5
in_features, out_features = 3, 3
models = [Foo(in_features, out_features) for i in range(num_models)]
params, buffers = stack_module_state(models)
print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

警告

所有一起堆疊的模組必須是相同的(除了它們的引數/緩衝區的值)。例如,它們應該處於相同的模式(訓練或評估)。

返回型別

tuple[dict[str, Any], dict[str, Any]]