從 functorch 遷移到 torch.func#
建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日
torch.func,之前稱為“functorch”,是 PyTorch 的 類 JAX 的可組合函式變換。
functorch 最初是 `pytorch/functorch` 倉庫中的一個獨立庫。我們的目標一直是將 functorch 直接合併到 PyTorch 中,並將其作為核心 PyTorch 庫提供。
作為合併的最後一步,我們決定從一個頂級包(`functorch`)遷移到 PyTorch 的一部分,以反映函式變換如何直接整合到 PyTorch 核心中。從 PyTorch 2.0 開始,我們棄用 `import functorch`,並要求使用者遷移到我們將繼續維護的最新 API。`import functorch` 將保留幾期以維持向後相容性。
函式變換#
以下 API 是以下 functorch API 的直接替換。它們完全向後相容。
functorch API |
PyTorch API(截至 PyTorch 2.0) |
|---|---|
functorch.vmap |
|
functorch.grad |
|
functorch.vjp |
|
functorch.jvp |
|
functorch.jacrev |
|
functorch.jacfwd |
|
functorch.hessian |
|
functorch.functionalize |
此外,如果您使用的是 torch.autograd.functional API,請嘗試使用 torch.func 的等效 API。在許多情況下,torch.func 的函式變換更具可組合性,效能也更好。
torch.autograd.functional API |
torch.func API(截至 PyTorch 2.0) |
|---|---|
NN 模組實用工具#
我們更改了 API,以將函式變換應用於 NN 模組,使其更符合 PyTorch 的設計理念。新 API 不同,因此請仔細閱讀本節。
functorch.make_functional#
torch.func.functional_call() 是 functorch.make_functional 和 functorch.make_functional_with_buffers 的替代品。但它不是直接替換。
如果您急需,可以使用 此 gist 中的輔助函式 來模擬 functorch.make_functional 和 functorch.make_functional_with_buffers 的行為。我們建議直接使用 torch.func.functional_call(),因為它是一個更明確、更靈活的 API。
具體來說,functorch.make_functional 返回一個函式式模組和引數。函式式模組接受引數、模型輸入作為引數。torch.func.functional_call() 允許使用新的引數、緩衝區和輸入呼叫現有模組的前向傳遞。
這裡有一個例子,說明如何使用 functorch 與 torch.func 計算模型引數的梯度。
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
def compute_loss(params, inputs, targets):
prediction = fmodel(params, inputs)
return torch.nn.functional.mse_loss(prediction, targets)
grads = functorch.grad(compute_loss)(params, inputs, targets)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
def compute_loss(params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)
grads = torch.func.grad(compute_loss)(params, inputs, targets)
這裡有一個計算模型引數雅可比矩陣的例子。
# ---------------
# using functorch
# ---------------
import torch
import functorch
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
fmodel, params = functorch.make_functional(model)
jacobians = functorch.jacrev(fmodel)(params, inputs)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import torch
from torch.func import jacrev, functional_call
inputs = torch.randn(64, 3)
model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
# jacrev computes jacobians of argnums=0 by default.
# We set it to 1 to compute jacobians of params
jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))
請注意,為了節約記憶體,您應該只保留引數的單個副本。model.named_parameters() 不會複製引數。如果在模型訓練中原地更新模型的引數,那麼您的模型 `nn.Module` 擁有引數的單個副本,一切都正常。
但是,如果您想將引數儲存在一個字典中並進行非原地更新,那麼就會存在兩個引數副本:字典中的一個,以及 `model` 中的一個。在這種情況下,您應該透過將 `model` 轉換為元裝置(`model.to('meta')`)來使其不持有記憶體。
functorch.combine_state_for_ensemble#
請使用 torch.func.stack_module_state() 來代替 functorch.combine_state_for_ensemble。 torch.func.stack_module_state() 返回兩個字典,一個包含堆疊的引數,另一個包含堆疊的緩衝區,然後這些可以與 torch.vmap() 和 torch.func.functional_call() 一起用於集合。
例如,這是一個關於如何對一個非常簡單的模型進行集合的例子。
import torch
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
# ---------------
# using functorch
# ---------------
import functorch
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
# ------------------------------------
# using torch.func (as of PyTorch 2.0)
# ------------------------------------
import copy
# Construct a version of the model with no memory by putting the Tensors on
# the meta device.
base_model = copy.deepcopy(models[0])
base_model.to('meta')
params, buffers = torch.func.stack_module_state(models)
# It is possible to vmap directly over torch.func.functional_call,
# but wrapping it in a function makes it clearer what is going on.
def call_single_model(params, buffers, data):
return torch.func.functional_call(base_model, (params, buffers), (data,))
output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
functorch.compile#
我們不再支援 functorch.compile(也稱為 AOTAutograd)作為 PyTorch 中編譯的前端;我們已將 AOTAutograd 整合到 PyTorch 的編譯流程中。如果您是使用者,請改用 torch.compile()。