評價此頁

從 functorch 遷移到 torch.func#

建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日

torch.func,之前稱為“functorch”,是 PyTorch 的 類 JAX 的可組合函式變換。

functorch 最初是 `pytorch/functorch` 倉庫中的一個獨立庫。我們的目標一直是將 functorch 直接合併到 PyTorch 中,並將其作為核心 PyTorch 庫提供。

作為合併的最後一步,我們決定從一個頂級包(`functorch`)遷移到 PyTorch 的一部分,以反映函式變換如何直接整合到 PyTorch 核心中。從 PyTorch 2.0 開始,我們棄用 `import functorch`,並要求使用者遷移到我們將繼續維護的最新 API。`import functorch` 將保留幾期以維持向後相容性。

函式變換#

以下 API 是以下 functorch API 的直接替換。它們完全向後相容。

functorch API

PyTorch API(截至 PyTorch 2.0)

functorch.vmap

torch.vmap()torch.func.vmap()

functorch.grad

torch.func.grad()

functorch.vjp

torch.func.vjp()

functorch.jvp

torch.func.jvp()

functorch.jacrev

torch.func.jacrev()

functorch.jacfwd

torch.func.jacfwd()

functorch.hessian

torch.func.hessian()

functorch.functionalize

torch.func.functionalize()

此外,如果您使用的是 torch.autograd.functional API,請嘗試使用 torch.func 的等效 API。在許多情況下,torch.func 的函式變換更具可組合性,效能也更好。

NN 模組實用工具#

我們更改了 API,以將函式變換應用於 NN 模組,使其更符合 PyTorch 的設計理念。新 API 不同,因此請仔細閱讀本節。

functorch.make_functional#

torch.func.functional_call()functorch.make_functionalfunctorch.make_functional_with_buffers 的替代品。但它不是直接替換。

如果您急需,可以使用 此 gist 中的輔助函式 來模擬 functorch.make_functional 和 functorch.make_functional_with_buffers 的行為。我們建議直接使用 torch.func.functional_call(),因為它是一個更明確、更靈活的 API。

具體來說,functorch.make_functional 返回一個函式式模組和引數。函式式模組接受引數、模型輸入作為引數。torch.func.functional_call() 允許使用新的引數、緩衝區和輸入呼叫現有模組的前向傳遞。

這裡有一個例子,說明如何使用 functorch 與 torch.func 計算模型引數的梯度。

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)

def compute_loss(params, inputs, targets):
    prediction = fmodel(params, inputs)
    return torch.nn.functional.mse_loss(prediction, targets)

grads = functorch.grad(compute_loss)(params, inputs, targets)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())

def compute_loss(params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

grads = torch.func.grad(compute_loss)(params, inputs, targets)

這裡有一個計算模型引數雅可比矩陣的例子。

# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)

params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

請注意,為了節約記憶體,您應該只保留引數的單個副本。model.named_parameters() 不會複製引數。如果在模型訓練中原地更新模型的引數,那麼您的模型 `nn.Module` 擁有引數的單個副本,一切都正常。

但是,如果您想將引數儲存在一個字典中並進行非原地更新,那麼就會存在兩個引數副本:字典中的一個,以及 `model` 中的一個。在這種情況下,您應該透過將 `model` 轉換為元裝置(`model.to('meta')`)來使其不持有記憶體。

functorch.combine_state_for_ensemble#

請使用 torch.func.stack_module_state() 來代替 functorch.combine_state_for_ensembletorch.func.stack_module_state() 返回兩個字典,一個包含堆疊的引數,另一個包含堆疊的緩衝區,然後這些可以與 torch.vmap()torch.func.functional_call() 一起用於集合。

例如,這是一個關於如何對一個非常簡單的模型進行集合的例子。

import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)

# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy

# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')

params, buffers = torch.func.stack_module_state(models)

# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
    return torch.func.functional_call(base_model, (params, buffers), (data,))

output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)

functorch.compile#

我們不再支援 functorch.compile(也稱為 AOTAutograd)作為 PyTorch 中編譯的前端;我們已將 AOTAutograd 整合到 PyTorch 的編譯流程中。如果您是使用者,請改用 torch.compile()