torch.func.functionalize#
- torch.func.functionalize(func, *, remove='mutations')[原始碼]#
functionalize 是一個轉換,可以用來移除函式中的(中間)變異和別名,同時保持函式的語義。
functionalize(func)返回一個新函式,該函式與func具有相同的語義,但移除了所有中間突變。對中間張量執行的每個原地(inplace)操作:intermediate.foo_()都將被替換為其非原地(out-of-place)等價操作:intermediate_updated = intermediate.foo()。functionalize 對於將 PyTorch 程式交付給不易表示突變或別名運算子的後端或編譯器非常有用。
- 引數
func (Callable) – 一個接受一個或多個引數的 Python 函式。
remove (str) – 一個可選的字串引數,取值為 ‘mutations’ 或 ‘mutations_and_views’。如果傳入 ‘mutations’,則所有變異運算子都將被替換為其非變異等價操作。如果傳入 ‘mutations_and_views’,則此外,所有別名運算子都將被替換為其非別名等價操作。預設值:‘mutations’。
- 返回
返回一個新“功能化”的函式。它接收與
func相同的輸入,並具有相同的行為,但函式中對中間張量執行的任何變異(以及可選的別名)都將被移除。- 返回型別
functionalize 還會移除對函式輸入執行的變異(和檢視)。但是,為了保持語義,functionalize 將在轉換執行完成後“修復”變異,方法是檢測是否有張量輸入“應該”被變異,並在必要時將新資料複製回輸入。
示例
>>> import torch >>> from torch.fx.experimental.proxy_tensor import make_fx >>> from torch.func import functionalize >>> >>> # A function that uses mutations and views, but only on intermediate tensors. >>> def f(a): ... b = a + 1 ... c = b.view(-1) ... c.add_(1) ... return b ... >>> inpt = torch.randn(2) >>> >>> out1 = f(inpt) >>> out2 = functionalize(f)(inpt) >>> >>> # semantics are the same (outputs are equivalent) >>> print(torch.allclose(out1, out2)) True >>> >>> f_traced = make_fx(f)(inpt) >>> f_no_mutations_traced = make_fx(functionalize(f))(inpt) >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> >>> print(f_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]) add_ = torch.ops.aten.add_(view, 1); view = None return add >>> print(f_no_mutations_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view = torch.ops.aten.view(add, [-1]); add = None add_1 = torch.ops.aten.add(view, 1); view = None view_1 = torch.ops.aten.view(add_1, [2]); add_1 = None return view_1 >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): add = torch.ops.aten.add(a_1, 1); a_1 = None view_copy = torch.ops.aten.view_copy(add, [-1]); add = None add_1 = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add_1, [2]); add_1 = None return view_copy_1 >>> # A function that mutates its input tensor >>> def f(a): ... b = a.view(-1) ... b.add_(1) ... return a ... >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt) >>> # >>> # All mutations and views have been removed, >>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input >>> # after the function has completed. >>> print(f_no_mutations_and_views_traced.code) def forward(self, a_1): view_copy = torch.ops.aten.view_copy(a_1, [-1]) add = torch.ops.aten.add(view_copy, 1); view_copy = None view_copy_1 = torch.ops.aten.view_copy(add, [2]); add = None copy_ = torch.ops.aten.copy_(a_1, view_copy_1); a_1 = None return view_copy_1
- 有幾個值得指出的 functionalize 的“失敗模式”:
與其他 torch.func 轉換一樣,functionalize() 不適用於直接使用 .backward() 的函式。torch.autograd.grad 也是如此。如果你想使用 autograd,你可以直接用 functionalize(grad(f)) 計算梯度。
與其他 torch.func 轉換一樣,functionalize() 不適用於全域性狀態。如果你對使用非區域性狀態的檢視/變異的函式呼叫 functionalize(f),functionalization 將簡單地不執行任何操作,並將檢視/變異呼叫直接傳遞給後端。一種解決方法是確保任何非區域性狀態的建立都包裝在一個更大的函式中,然後你對該函式呼叫 functionalize。
resize_() 有一些限制:functionalize 僅適用於使用 `resize_()` 的程式,前提是正在調整大小的張量不是檢視。
as_strided() 有一些限制:functionalize 不適用於導致張量具有重疊記憶體的 as_strided() 呼叫。
最後,理解功能化的一個有用的心智模型是,大多數使用者編寫的 PyTorch 程式都是使用公共 torch API。執行時,torch 運算子通常被分解為我們內部的 C++ “ATen” API。功能化的邏輯完全發生在 ATen 層面。功能化知道如何將 ATen 中的每個別名運算子對映到其非別名等價操作(例如
tensor.view({-1})->at::view_copy(tensor, {-1})),以及如何將 ATen 中的每個變異運算子對映到其非變異等價操作(例如tensor.add_(1)->at::add(tensor, -1)),同時進行離線別名和變異跟蹤,以瞭解何時進行修復。關於哪些 ATen 運算子是別名或變異的資訊全部來自 pytorch/pytorch。