functorch.make_functional¶
-
functorch.make_functional(model, disable_autograd_tracking=False) → func, params[source]¶ 給定一個
torch.nn.Module,make_functional()會萃取狀態 (params) 並傳回模型的函式版本func。這使得可以對model之參數使用轉換。func可如下呼叫import torch import torch.nn as nn from functorch import make_functional x = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) func(params, x)
以下是將 grad 轉換套用於模型參數的範例。
import torch import torch.nn as nn from functorch import make_functional, grad x = torch.randn(4, 3) t = torch.randn(4, 3) model = nn.Linear(3, 3) func, params = make_functional(model) def compute_loss(params, x, t): y = func(params, x) return nn.functional.mse_loss(y, t) grad_weights = grad(compute_loss)(params, x, t)
如果模型有任何緩衝區,請使用
make_functional_with_buffers()取代。- 參數
model (torch.nn.Module) – 輸入模型。
disable_autograd_tracking (bool) – 旗標,用於停用輸出參數的梯度追蹤。傳回的 params 與原始模型的 params 集合無關。如果是 False (預設值),params 會有
requires_grad=True(亦即它們可透過正規 PyTorch 自動微分追蹤),與原始模型的 params 的 requires_grad 相符。否則,傳回的 params 會有requires_grad=False。預設值:False。如果你計畫使用正規的 PyTorch 自動微分 (例如:如果你要呼叫.backward()或torch.autograd.grad(),請設定disable_autograd_tracking=False。否則,如果你只計畫使用 functorch 的梯度轉換,請設定disable_autograd_tracking=True,以避免透過 PyTorch 自動微分不必要追蹤歷程。
警告
我們已將 functorch 整合至 PyTorch。身為整合的最後一步,從 PyTorch 2.0 開始,functorch.make_functional 已不建議使用,並會在 PyTorch >= 2.3 的未來版本中移除。請使用 torch.func.functional_call 取代;請參閱 PyTorch 2.0 發行說明和/或 torch.func 遷移指南以取得更多詳情 https://pytorch.com.tw/docs/stable/func.migrating.html