torch.fx#
創建於: 2020年12月15日 | 最後更新於: 2025年07月15日
概述#
FX 是一個供開發者用來轉換 nn.Module 例項的工具包。FX 包含三個主要元件:符號追蹤器、中間表示 和 Python 程式碼生成。這些元件協同工作的演示
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%param : [num_users=1] = get_attr[target=param]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
return clamp
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
param = self.param
add = x + param; x = param = None
linear = self.linear(add); add = None
clamp = linear.clamp(min = 0.0, max = 1.0); linear = None
return clamp
"""
符號追蹤器 對 Python 程式碼執行“符號執行”。它透過程式碼傳遞稱為代理(Proxies)的假值。對這些代理的操作會被記錄下來。有關符號追蹤的更多資訊可以在 symbolic_trace() 和 Tracer 文件中找到。
中間表示 是用於存放符號追蹤期間記錄的操作的容器。它包含一系列節點(Nodes),這些節點代表函式輸入、呼叫點(指向函式、方法或 torch.nn.Module 例項)以及返回值。有關 IR 的更多資訊可以在 Graph 文件中找到。IR 是應用轉換的格式。
Python 程式碼生成 是 FX 成為 Python 到 Python(或模組到模組)轉換工具包的原因。對於每個 Graph IR,我們可以生成與 Graph 語義相匹配的有效 Python 程式碼。此功能封裝在 GraphModule 中,它是一個 torch.nn.Module 例項,其中包含一個 Graph 以及從 Graph 生成的 forward 方法。
總而言之,這個元件管道(符號追蹤 -> 中間表示 -> 轉換 -> Python 程式碼生成)構成了 FX 的 Python 到 Python 轉換管道。此外,這些元件也可以單獨使用。例如,符號追蹤可以單獨用於捕獲程式碼的一種形式以便進行分析(而非轉換)。程式碼生成可用於以程式設計方式生成模型,例如從配置檔案生成。FX 有許多用途!
可以在 examples 儲存庫中找到幾個轉換示例。
編寫轉換#
什麼是 FX 轉換?本質上,它是一個如下所示的函式。
import torch
import torch.fx
def transform(m: nn.Module,
tracer_class : type = torch.fx.Tracer) -> torch.nn.Module:
# Step 1: Acquire a Graph representing the code in `m`
# NOTE: torch.fx.symbolic_trace is a wrapper around a call to
# fx.Tracer.trace and constructing a GraphModule. We'll
# split that out in our transform to allow the caller to
# customize tracing behavior.
graph : torch.fx.Graph = tracer_class().trace(m)
# Step 2: Modify this Graph or create a new one
graph = ...
# Step 3: Construct a Module to return
return torch.fx.GraphModule(m, graph)
您的轉換將接收一個 torch.nn.Module,從中獲取一個 Graph,進行一些修改,然後返回一個新的 torch.nn.Module。您應該將 FX 轉換返回的 torch.nn.Module 視為與常規 torch.nn.Module 相同 —— 您可以將其傳遞給另一個 FX 轉換,或者執行它。確保您的 FX 轉換的輸入和輸出是 torch.nn.Module 將允許組合。
注意
也可以修改現有的 GraphModule 而不是建立新的,如下所示:
import torch
import torch.fx
def transform(m : nn.Module) -> nn.Module:
gm : torch.fx.GraphModule = torch.fx.symbolic_trace(m)
# Modify gm.graph
# <...>
# Recompile the forward() method of `gm` from its Graph
gm.recompile()
return gm
請注意,您必須呼叫 GraphModule.recompile() 來使 GraphModule 上生成的 forward() 方法與修改後的 Graph 同步。
鑑於您已經傳入一個已追蹤成 Graph 的 torch.nn.Module,現在有兩種主要方法可以用來構建一個新的 Graph。
快速圖入門#
圖語義的完整介紹可以在 Graph 文件中找到,但我們在這裡將介紹基礎知識。 Graph 是一個表示 GraphModule 中方法的_資料結構。它所需的資訊是:
方法有哪些輸入?
方法內部執行的操作是什麼?
方法的輸出(即返回值)是什麼?
所有這三個概念都用 Node 例項表示。讓我們用一個簡單的例子來看看這是什麼意思:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)
def forward(self, x):
return torch.topk(torch.sum(
self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
gm.graph.print_tabular()
這裡我們定義了一個模組 MyModule 用於演示,例項化它,符號追蹤它,然後呼叫 Graph.print_tabular() 方法來列印一個表格,顯示這個 Graph 的節點。
操作碼 |
名稱 |
目標 |
引數 |
關鍵字引數 |
|---|---|---|---|---|
佔位符 |
x |
x |
() |
{} |
get_attr |
linear_weight |
linear.weight |
() |
{} |
call_function |
add_1 |
(x, linear_weight) |
{} |
|
call_module |
linear_1 |
linear |
(add_1,) |
{} |
call_method |
relu_1 |
relu |
(linear_1,) |
{} |
call_function |
sum_1 |
<內建方法 sum …> |
(relu_1,) |
{‘dim’: -1} |
call_function |
topk_1 |
<內建方法 topk …> |
(sum_1, 3) |
{} |
output |
output |
output |
(topk_1,) |
{} |
我們可以利用這些資訊來回答上面提出的問題。
方法有哪些輸入?在 FX 中,方法輸入透過特殊的
placeholder節點指定。在這種情況下,我們有一個帶有target為x的placeholder節點,這意味著我們有一個名為 x 的單個(非 self)引數。方法內部的操作是什麼?
get_attr、call_function、call_module和call_method節點代表方法中的操作。所有這些操作的完整語義可以在Node文件中找到。方法的返回值是什麼?
Graph中的返回值由特殊的output節點指定。
鑑於我們現在瞭解了 FX 中程式碼表示的基本知識,我們現在可以探索如何編輯 Graph。
圖操作#
直接圖操作#
構建新 Graph 的一種方法是直接操作舊的圖。為此,我們可以簡單地獲取從符號追蹤獲得的 Graph 並對其進行修改。例如,假設我們希望將 torch.add() 呼叫替換為 torch.mul() 呼叫。
import torch
import torch.fx
# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return torch.add(x, y)
def transform(m: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph : fx.Graph = tracer_class().trace(m)
# FX represents its Graph as an ordered list of
# nodes, so we can iterate through them.
for node in graph.nodes:
# Checks if we're calling a function (i.e:
# torch.add)
if node.op == 'call_function':
# The target attribute is the function
# that call_function calls.
if node.target == torch.add:
node.target = torch.mul
graph.lint() # Does some checks to make sure the
# Graph is well-formed.
return fx.GraphModule(m, graph)
我們還可以進行更復雜的 Graph 重寫,例如刪除或追加節點。為了輔助這些轉換,FX 在 Graph 文件中提供了用於轉換圖的實用函式。下面有一個使用這些 API 追加 torch.relu() 呼叫的示例。
# Specifies the insertion point. Any nodes added to the
# Graph within this scope will be inserted after `node`
with traced.graph.inserting_after(node):
# Insert a new `call_function` node calling `torch.relu`
new_node = traced.graph.call_function(
torch.relu, args=(node,))
# We want all places that used the value of `node` to
# now use that value after the `relu` call we've added.
# We use the `replace_all_uses_with` API to do this.
node.replace_all_uses_with(new_node)
對於僅包含替換的簡單轉換,您還可以使用 子圖重寫器。
使用 replace_pattern() 進行子圖重寫#
FX 還提供了比直接圖操作更高級別的自動化。 replace_pattern() API 基本上是一個用於編輯 Graph 的“查詢/替換”工具。它允許您指定一個 pattern 和一個 replacement 函式,它將追蹤這些函式,在 pattern 圖中找到操作組的例項,然後用 replacement 圖的副本替換這些例項。這有助於極大地自動化繁瑣的圖操作程式碼,因為隨著轉換變得更復雜,程式碼可能會變得笨拙。
代理/重追蹤#
操作 Graph 的另一種方法是重用符號追蹤中使用的 Proxy 機制。例如,假設我們想編寫一個將 PyTorch 函式分解為更小操作的轉換。它會將每個 F.relu(x) 呼叫轉換為 (x > 0) * x。一種可能性是執行所需圖重寫,在 F.relu 之後插入比較和乘法,然後清理原始 F.relu。但是,我們可以透過使用 Proxy 物件自動將操作記錄到 Graph 中來自動化此過程。
使用此方法,我們將要插入的操作編寫為常規 PyTorch 程式碼,並使用 Proxy 物件作為引數來呼叫該程式碼。這些 Proxy 物件將捕獲對它們的執行的操作,並將它們追加到 Graph。
# Note that this decomposition rule can be read as regular Python
def relu_decomposition(x):
return (x > 0) * x
decomposition_rules = {}
decomposition_rules[F.relu] = relu_decomposition
def decompose(model: torch.nn.Module,
tracer_class : type = fx.Tracer) -> torch.nn.Module:
"""
Decompose `model` into smaller constituent operations.
Currently,this only supports decomposing ReLU into its
mathematical definition: (x > 0) * x
"""
graph : fx.Graph = tracer_class().trace(model)
new_graph = fx.Graph()
env = {}
tracer = torch.fx.proxy.GraphAppendingTracer(new_graph)
for node in graph.nodes:
if node.op == 'call_function' and node.target in decomposition_rules:
# By wrapping the arguments with proxies,
# we can dispatch to the appropriate
# decomposition rule and implicitly add it
# to the Graph by symbolically tracing it.
proxy_args = [
fx.Proxy(env[x.name], tracer) if isinstance(x, fx.Node) else x for x in node.args]
output_proxy = decomposition_rules[node.target](*proxy_args)
# Operations on `Proxy` always yield new `Proxy`s, and the
# return value of our decomposition rule is no exception.
# We need to extract the underlying `Node` from the `Proxy`
# to use it in subsequent iterations of this transform.
new_node = output_proxy.node
env[node.name] = new_node
else:
# Default case: we don't have a decomposition rule for this
# node, so just copy the node over into the new graph.
new_node = new_graph.node_copy(node, lambda x: env[x.name])
env[node.name] = new_node
return fx.GraphModule(model, new_graph)
除了避免顯式圖操作外,使用 Proxy 還允許您將重寫規則指定為本機 Python 程式碼。對於需要大量重寫規則的轉換(例如 vmap 或 grad),這通常可以提高規則的可讀性和可維護性。請注意,在呼叫 Proxy 時,我們還傳遞了一個指向底層變數 graph 的追蹤器。這是為了以防萬一操作是 n 元的(例如,add 是二元運算子),對 Proxy 的呼叫不會建立多個圖追蹤器例項,這可能導致意外的執行時錯誤。我們尤其推薦這種使用 Proxy 的方法,因為底層運算子不能安全地假定為一元運算子。
直譯器模式#
FX 中一個有用的程式碼組織模式是遍歷 Graph 中的所有 Node 並執行它們。這可用於多種用途,包括對流經圖的值進行執行時分析,或透過使用 Proxy 進行重追蹤來轉換程式碼。例如,假設我們想執行一個 GraphModule,並在執行時看到 torch.Tensor 的形狀和 dtype 屬性時記錄它們。這可能看起來像:
import torch
import torch.fx
from torch.fx.node import Node
from typing import Dict
class ShapeProp:
"""
Shape propagation. This class takes a `GraphModule`.
Then, its `propagate` method executes the `GraphModule`
node-by-node with the given arguments. As each operation
executes, the ShapeProp class stores away the shape and
element type for the output values of each operation on
the `shape` and `dtype` attributes of the operation's
`Node`.
"""
def __init__(self, mod):
self.mod = mod
self.graph = mod.graph
self.modules = dict(self.mod.named_modules())
def propagate(self, *args):
args_iter = iter(args)
env : Dict[str, Node] = {}
def load_arg(a):
return torch.fx.graph.map_arg(a, lambda n: env[n.name])
def fetch_attr(target : str):
target_atoms = target.split('.')
attr_itr = self.mod
for i, atom in enumerate(target_atoms):
if not hasattr(attr_itr, atom):
raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
attr_itr = getattr(attr_itr, atom)
return attr_itr
for node in self.graph.nodes:
if node.op == 'placeholder':
result = next(args_iter)
elif node.op == 'get_attr':
result = fetch_attr(node.target)
elif node.op == 'call_function':
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
# This is the only code specific to shape propagation.
# you can delete this `if` branch and this becomes
# a generic GraphModule interpreter.
if isinstance(result, torch.Tensor):
node.shape = result.shape
node.dtype = result.dtype
env[node.name] = result
return load_arg(self.graph.result)
如您所見,FX 的完整直譯器並不複雜,但可能非常有用。為了方便使用這種模式,我們提供了 Interpreter 類,它包含了上述邏輯,並透過方法覆蓋來覆蓋直譯器執行的某些方面。
除了執行操作之外,我們還可以透過將 Proxy 值傳遞給直譯器來生成新的 Graph。類似地,我們提供了 Transformer 類來包含此模式。 Transformer 的行為類似於 Interpreter,但不是呼叫 run 方法來獲取模組的具體輸出值,而是呼叫 Transformer.transform() 方法來返回一個新 GraphModule,該 GraphModule 應用了您作為覆蓋方法安裝的任何轉換規則。
除錯#
引言#
在編寫轉換的過程中,我們的程式碼可能並不總是正確的。在這種情況下,我們可能需要進行一些除錯。關鍵是反向工作:首先,檢查呼叫生成模組的結果以證明或證偽正確性。然後,檢查並除錯生成的程式碼。然後,除錯導致生成程式碼的轉換過程。
如果您不熟悉偵錯程式,請參閱輔助部分 可用偵錯程式。
轉換創作中的常見陷阱#
不確定的
set迭代順序。在 Python 中,set資料型別是無序的。使用set來包含諸如Node等物件的集合,可能會導致意外的非確定性。一個例子是迭代一組Node以將它們插入Graph。由於set資料型別是無序的,輸出程式中操作的順序將是非確定性的,並且可能在程式呼叫之間發生變化。推薦的替代方法是使用dict資料型別,它從 Python 3.7 開始(以及從 cpython 3.6 開始)是按插入順序排序的。透過儲存在dict的鍵中的值,可以等效地使用dict來實現集合去重。
檢查模組的正確性#
由於大多數深度學習模組的輸出是浮點 torch.Tensor 例項,因此檢查兩個 torch.nn.Module 的結果是否等效不像執行簡單的相等檢查那樣直接。為了說明這一點,讓我們舉一個例子:
import torch
import torch.fx
import torchvision.models as models
def transform(m : torch.nn.Module) -> torch.nn.Module:
gm = torch.fx.symbolic_trace(m)
# Imagine we're doing some transforms here
# <...>
gm.recompile()
return gm
resnet18 = models.resnet18()
transformed_resnet18 = transform(resnet18)
input_image = torch.randn(5, 3, 224, 224)
assert resnet18(input_image) == transformed_resnet18(input_image)
"""
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
"""
在這裡,我們嘗試使用 == 相等運算子檢查兩個深度學習模型的輸出值是否相等。然而,這既不明
確,因為該運算子返回的是張量而不是布林值,而且由於浮點值的比較應該使用誤差範圍(或 epsilon)來考慮浮點運算的非交換性(更多詳細資訊請參閱此處)。我們可以改用 torch.allclose(),它將提供一個近似比較,同時考慮相對和絕對容差閾值。
assert torch.allclose(resnet18(input_image), transformed_resnet18(input_image))
這是我們工具箱中用於檢查轉換後的模組是否按預期行為與參考實現相比的第一個工具。
除錯生成的程式碼#
由於 FX 在 GraphModule 上生成 forward() 函式,因此使用傳統的除錯技術,如 print 語句或 pdb,並不那麼直接。幸運的是,我們有幾種技術可用於除錯生成的程式碼。
使用 pdb#
呼叫 pdb 進入正在執行的程式。雖然表示 Graph 的程式碼不在任何原始檔中,但在呼叫前向傳播時,我們仍然可以透過 pdb 手動進入它。
import torch
import torch.fx
import torchvision.models as models
def my_pass(inp: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
graph = tracer_class().trace(inp)
# Transformation logic here
# <...>
# Return new Module
return fx.GraphModule(inp, graph)
my_module = models.resnet18()
my_module_transformed = my_pass(my_module)
input_value = torch.randn(5, 3, 224, 224)
# When this line is executed at runtime, we will be dropped into an
# interactive `pdb` prompt. We can use the `step` or `s` command to
# step into the execution of the next line
import pdb; pdb.set_trace()
my_module_transformed(input_value)
列印生成的程式碼#
如果您想多次執行相同的程式碼,那麼使用 pdb 導航到正確的程式碼可能會有些繁瑣。在這種情況下,一種方法是簡單地將生成的 forward 傳遞複製到您的程式碼中,然後從那裡進行檢查。
# Assume that `traced` is a GraphModule that has undergone some
# number of transforms
# Copy this code for later
print(traced)
# Print the code generated from symbolic tracing. This outputs:
"""
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
"""
# Subclass the original Module
class SubclassM(M):
def __init__(self):
super().__init__()
# Paste the generated `forward` function (the one we printed and
# copied above) here
def forward(self, y):
x = self.x
add_1 = x + y; x = y = None
return add_1
# Create an instance of the original, untraced Module. Then, create an
# instance of the Module with the copied `forward` function. We can
# now compare the output of both the original and the traced version.
pre_trace = M()
post_trace = SubclassM()
使用 GraphModule 中的 to_folder 函式#
GraphModule.to_folder() 是 GraphModule 中的一個方法,它允許您將生成的 FX 程式碼轉儲到一個資料夾。雖然像 列印生成的程式碼 中那樣將前向傳遞複製到程式碼中通常就足夠了,但使用 to_folder 檢查模組和引數可能更容易。
m = symbolic_trace(M())
m.to_folder("foo", "Bar")
from foo import Bar
y = Bar()
執行上述示例後,我們可以檢視 foo/module.py 中的程式碼,並根據需要進行修改(例如,新增 print 語句或使用 pdb)來除錯生成的程式碼。
除錯轉換#
現在我們已經確定某個轉換產生了不正確的程式碼,是時候除錯轉換本身了。首先,我們將檢查文件中的 符號追蹤的侷限性 部分。一旦我們確認追蹤按預期工作,目標就是弄清楚在 GraphModule 轉換過程中出了什麼問題。 編寫轉換 中可能有一個快速答案,但如果沒有,有幾種方法可以檢查我們的追蹤模組。
# Sample Module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y
# Create an instance of `M`
m = M()
# Symbolically trace an instance of `M` (returns a GraphModule). In
# this example, we'll only be discussing how to inspect a
# GraphModule, so we aren't showing any sample transforms for the
# sake of brevity.
traced = symbolic_trace(m)
# Print the code produced by tracing the module.
print(traced)
# The generated `forward` function is:
"""
def forward(self, x, y):
add = x + y; x = y = None
return add
"""
# Print the internal Graph.
print(traced.graph)
# This print-out returns:
"""
graph():
%x : [num_users=1] = placeholder[target=x]
%y : [num_users=1] = placeholder[target=y]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
return add
"""
# Print a tabular representation of the internal Graph.
traced.graph.print_tabular()
# This gives us:
"""
opcode name target args kwargs
------------- ------ ----------------------- ------ --------
placeholder x x () {}
placeholder y y () {}
call_function add <built-in function add> (x, y) {}
output output output (add,) {}
"""
使用上述實用函式,我們可以比較我們追蹤的模組在應用轉換之前和之後的狀態。有時,簡單的目視比較足以追溯錯誤。如果仍然不清楚問題所在,偵錯程式如 pdb 是一個不錯的選擇。
基於上面的例子,考慮以下程式碼:
# Sample user-defined function
def transform_graph(module: torch.nn.Module, tracer_class : type = fx.Tracer) -> torch.nn.Module:
# Get the Graph from our traced Module
g = tracer_class().trace(module)
"""
Transformations on `g` go here
"""
return fx.GraphModule(module, g)
# Transform the Graph
transformed = transform_graph(traced)
# Print the new code after our transforms. Check to see if it was
# what we expected
print(transformed)
使用上面的例子,假設 print(traced) 的呼叫顯示我們的轉換存在錯誤。我們想使用偵錯程式找出問題所在。我們啟動一個 pdb 會話。我們可以透過在 transform_graph(traced) 上設定斷點,然後按 s “步入” transform_graph(traced) 的呼叫來檢視轉換過程中發生的情況。
我們也可以嘗試編輯 print_tabular 方法來列印節點在圖中的不同屬性。(例如,我們可能想檢視節點的 input_nodes 和 users。)
可用偵錯程式#
最常見的 Python 偵錯程式是 pdb。您可以透過在命令列中鍵入 python -m pdb FILENAME.py 來以“除錯模式”啟動程式,其中 FILENAME 是您要除錯的檔名。之後,您可以使用 pdb 除錯命令逐步執行正在執行的程式。通常,在啟動 pdb 時設定一個斷點(b LINE-NUMBER),然後呼叫 c 來執行程式直到該點。這可以避免您必須逐行執行(使用 s 或 n)才能到達要檢查的程式碼部分。或者,您可以將 import pdb; pdb.set_trace() 放在您想中斷的行之前。如果您添加了 pdb.set_trace(),程式在執行時會自動進入除錯模式。(換句話說,您只需在命令列中鍵入 python FILENAME.py 而不是 python -m pdb FILENAME.py。)一旦您的檔案在除錯模式下執行,您就可以使用某些命令逐步執行程式碼並檢查程式的內部狀態。線上有許多關於 pdb 的優秀教程,包括 RealPython 的 “Python 除錯與 Pdb”。
PyCharm 或 VSCode 等 IDE 通常內建有偵錯程式。在您的 IDE 中,您可以選擇 a) 透過在 IDE 中開啟終端視窗(例如,VSCode 中的 View → Terminal)來使用 pdb,或者 b) 使用內建偵錯程式(通常是 pdb 的圖形包裝器)。
符號追蹤的侷限性#
FX 使用一種稱為符號追蹤(也稱為 符號執行)的系統,以可轉換/可分析的形式捕獲程式的語義。該系統是追蹤的,因為它執行程式(實際上是一個 torch.nn.Module 或函式)來記錄操作。它是符號的,因為在此執行期間流經程式的資料不是真實資料,而是符號(在 FX 術語中稱為 Proxy)。
儘管符號追蹤適用於大多數神經網路程式碼,但它也有一些侷限性。
動態控制流#
符號追蹤的主要限制是它目前不支援動態控制流。也就是說,迴圈或 if 語句,其條件可能取決於程式的輸入值。
例如,讓我們檢查以下程式:
def func_to_trace(x):
if x.sum() > 0:
return torch.relu(x)
else:
return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
<...>
File "dyn.py", line 6, in func_to_trace
if x.sum() > 0:
File "pytorch/torch/fx/proxy.py", line 155, in __bool__
return self.tracer.to_bool(self)
File "pytorch/torch/fx/proxy.py", line 85, in to_bool
raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
到 if 語句的條件依賴於 x.sum() 的值,而 x.sum() 又依賴於函式輸入 x 的值。由於 x 可以改變(即,如果您將新的輸入張量傳遞給追蹤的函式),這就是動態控制流。回溯會沿著您的程式碼向上追溯,以顯示這種情況發生的位置。
靜態控制流#
另一方面,所謂的靜態控制流是受支援的。靜態控制流是迴圈或 if 語句,其值在呼叫之間不會改變。通常,在 PyTorch 程式中,這種控制流出現在根據超引數對模型的架構做出決策的程式碼中。作為一個具體的例子:
import torch
import torch.fx
class MyModule(torch.nn.Module):
def __init__(self, do_activation : bool = False):
super().__init__()
self.do_activation = do_activation
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
x = self.linear(x)
# This if-statement is so-called static control flow.
# Its condition does not depend on any input values
if self.do_activation:
x = torch.relu(x)
return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
linear_1 = self.linear(x); x = None
return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
relu_1 = torch.relu(linear_1); linear_1 = None
return relu_1
"""
if 語句 if self.do_activation 不依賴於任何函式輸入,因此它是靜態的。 do_activation 可以被認為是超引數,並且具有不同該引數值的 MyModule 例項的不同追蹤會產生不同的程式碼。這是一個有效的模式,並且受到符號追蹤的支援。
許多動態控制流的例項實際上是靜態控制流。這些例項可以透過消除對輸入值的依賴關係來支援符號追蹤,例如,透過將值移到 Module 屬性中,或在符號追蹤期間將具體值繫結到引數。
def f(x, flag):
if flag: return x
else: return x*2
fx.symbolic_trace(f) # Fails!
fx.symbolic_trace(f, concrete_args={'flag': True})
對於真正的動態控制流,包含這些程式碼的程式部分可以被追蹤為對 Method(參見 使用 Tracer 類定製追蹤)或函式(參見 wrap())的呼叫,而不是追蹤其內部。
非 torch 函式#
FX 使用 __torch_function__ 作為它攔截呼叫的機制(有關更多資訊,請參閱 技術概述)。某些函式,例如內建 Python 函式或 math 模組中的函式,不受 __torch_function__ 的覆蓋,但我們仍然希望在符號追蹤中捕獲它們。例如:
import torch
import torch.fx
from math import sqrt
def normalize(x):
"""
Normalize `x` by the size of the batch dimension
"""
return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
<...>
File "sqrt.py", line 9, in normalize
return x / sqrt(len(x))
File "pytorch/torch/fx/proxy.py", line 161, in __len__
raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
錯誤告訴我們內建函式 len 不受支援。我們可以使用 wrap() API 使諸如 len 之類的函式在追蹤中被記錄為直接呼叫。
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
len_1 = len(x)
sqrt_1 = math.sqrt(len_1); len_1 = None
truediv = x / sqrt_1; x = sqrt_1 = None
return truediv
"""
使用 Tracer 類定製追蹤#
Tracer 類是 symbolic_trace 實現的基礎。透過繼承 Tracer 可以自定義追蹤行為,如下所示:
class MyCustomTracer(torch.fx.Tracer):
# Inside here you can override various methods
# to customize tracing. See the `Tracer` API
# reference
pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
葉子模組#
葉子模組是出現在符號追蹤中的模組,而不是被追蹤穿透的模組。預設的葉子模組集是標準的 torch.nn 模組例項集。例如:
class MySpecialSubmodule(torch.nn.Module):
def forward(self, x):
return torch.neg(x)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 4)
self.submod = MySpecialSubmodule()
def forward(self, x):
return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced though.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
linear_1 = self.linear(x); x = None
neg_1 = torch.neg(linear_1); linear_1 = None
return neg_1
"""
透過覆蓋 Tracer.is_leaf_module() 可以自定義葉子模組集。
雜項#
張量建構函式(例如
torch.zeros、torch.ones、torch.rand、torch.randn、torch.sparse_coo_tensor)目前不可追蹤。確定性建構函式(
zeros、ones)可以使用,並且它們生成的值將作為常量嵌入到追蹤中。如果這些建構函式的引數引用了動態輸入大小,這才是有問題的。在這種情況下,ones_like或zeros_like可能是可行的替代方案。非確定性建構函式(
rand、randn)將嵌入一個隨機值到追蹤中。這很可能不是預期的行為。一種變通方法是將torch.randn包裝在torch.fx.wrap函式中,然後呼叫它。
@torch.fx.wrap def torch_randn(x, shape): return torch.randn(shape) def f(x): return x + torch_randn(x, 5) fx.symbolic_trace(f)
此行為可能在將來的版本中修復。
型別註解
Python 3 風格的型別註解(例如
func(x : torch.Tensor, y : int) -> torch.Tensor)是受支援的,並且將由符號追蹤保留。Python 2 風格的註釋型別註解
# type: (torch.Tensor, int) -> torch.Tensor目前不受支援。函式內區域性名稱的註解目前不受支援。
關於
training標誌和子模組的注意事項在使用
torch.nn.functional.dropout等函式時,通常會將 training 引數作為self.training傳入。在 FX 追蹤期間,這很可能會作為常量值被烘焙進去。
import torch import torch.fx class DropoutRepro(torch.nn.Module): def forward(self, x): return torch.nn.functional.dropout(x, training=self.training) traced = torch.fx.symbolic_trace(DropoutRepro()) print(traced.code) """ def forward(self, x): dropout = torch.nn.functional.dropout(x, p = 0.5, training = True, inplace = False); x = None return dropout """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x) """ AssertionError: Tensor-likes are not close! Mismatched elements: 15 / 15 (100.0%) Greatest absolute difference: 1.6207983493804932 at index (0, 2) (up to 1e-05 allowed) Greatest relative difference: 1.0 at index (0, 0) (up to 0.0001 allowed) """
但是,當使用標準的
nn.Dropout()子模組時,training 標誌是封裝的,並且—由於nn.Module物件模型的保留—可以更改。
class DropoutRepro2(torch.nn.Module): def __init__(self): super().__init__() self.drop = torch.nn.Dropout() def forward(self, x): return self.drop(x) traced = torch.fx.symbolic_trace(DropoutRepro2()) print(traced.code) """ def forward(self, x): drop = self.drop(x); x = None return drop """ traced.eval() x = torch.randn(5, 3) torch.testing.assert_close(traced(x), x)
因此,請考慮將與
training標誌動態互動的模組標記為葉子模組。
API 參考#
- torch.fx.symbolic_trace(root, concrete_args=None)[source]#
符號追蹤 API
給定一個
nn.Module或函式例項root,此函式將返回一個GraphModule,該GraphModule透過記錄追蹤root時看到的_操作來構建。concrete_args允許您部分特化函式,無論是為了消除控制流還是資料結構。例如
def f(a, b): if b == True: return a else: return a * 2
FX 通常無法追蹤此內容,因為其中存在控制流。但是,我們可以使用concrete_args 來特化b 的值以追蹤此內容。
f = fx.symbolic_trace(f, concrete_args={"b": False}) assert f(3, False) == 6
請注意,儘管您仍然可以傳入 b 的不同值,但它們將被忽略。
我們還可以使用concrete_args 來消除函式中的資料結構處理。這將使用 pytrees 來展平您的輸入。為避免過度特化,請為不應特化的值傳入 fx.PH。例如:
def f(x): out = 0 for v in x.values(): out += v return out f = fx.symbolic_trace( f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} ) assert f({"a": 1, "b": 2, "c": 4}) == 7
- 引數
root (Union[torch.nn.Module, Callable]) – 要追蹤並轉換為 Graph 表示的模組或函式。
concrete_args (Optional[Dict[str, any]]) – 要部分特化的輸入
- 返回
一個由
root記錄的操作構建的模組。- 返回型別
注意
此 API 的向後相容性已得到保證。
- torch.fx.wrap(fn_or_name)[source]#
此函式可以在模組級別作用域呼叫,以將 fn_or_name 註冊為“葉子函式”。“葉子函式”將在 FX 追蹤中被保留為 CallFunction 節點,而不是被追蹤穿透。
# foo/bar/baz.py def my_custom_function(x, y): return x * x + y * y torch.fx.wrap("my_custom_function") def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into # the graph rather than tracing it. return my_custom_function(x, y)
此函式也可以用作裝飾器
# foo/bar/baz.py @torch.fx.wrap def my_custom_function(x, y): return x * x + y * y
包裝後的函式可以被認為是“葉子函式”,類似於“葉子模組”的概念,即它們是在 FX 追蹤中保留為呼叫的函式,而不是被追蹤穿透。
- 引數
fn_or_name (Union[str, Callable]) – 當呼叫函式或全域性函式名時要插入圖中的函式或全域性函式名。
注意
此 API 的向後相容性已得到保證。
- class torch.fx.GraphModule(*args, **kwargs)[source]#
GraphModule 是從 fx.Graph 生成的 nn.Module。Graphmodule 具有
graph屬性,以及從該graph生成的code和forward屬性。警告
當
graph被重新賦值時,code和forward將被自動重新生成。但是,如果您在未重新賦值graph屬性本身的情況下編輯graph的內容,則必須呼叫recompile()來更新生成的程式碼。注意
此 API 的向後相容性已得到保證。
- __init__(root, graph, class_name='GraphModule')[source]#
構造一個 GraphModule。
- 引數
root (Union[torch.nn.Module, Dict[str, Any]) –
root可以是 nn.Module 例項,也可以是對映字串到任何屬性型別的 Dict。如果root是 Module,那麼 Graph 的 Nodes 的target欄位中對 Module 相關物件的引用(透過限定名)將從root的 Module 層次結構中的相應位置複製到 GraphModule 的模組層次結構中。如果root是 dict,那麼在 Node 的target中找到的限定名將直接在 dict 的鍵中查詢。由 Dict 對映的物件將被複制到 GraphModule 的模組層次結構中的相應位置。graph (Graph) –
graph包含此 GraphModule 應使用的節點來生成程式碼。class_name (str) –
name表示此 GraphModule 的名稱,用於除錯目的。如果未設定,所有錯誤訊息都將報告為源自GraphModule。將其設定為root的原始名稱或在您的轉換上下文中具有意義的名稱可能很有用。
注意
此 API 的向後相容性已得到保證。
- add_submodule(target, m)[source]#
將給定的子模組新增到
self。如果
target的子路徑尚不存在,則會在此安裝空模組。- 引數
- 返回
- 子模組是否可以被插入。為了
使此方法返回 True,由
target表示的鏈中的每個物件必須 a) 尚不存在,或 b) 引用一個nn.Module(而不是引數或其他屬性)。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- delete_all_unused_submodules()[source]#
從
self中刪除所有未使用的子模組。當滿足以下任一條件時,模組被認為是“已使用”:1. 它有已使用的子模組 2. 它的 forward 被直接透過
call_module節點呼叫 3. 它有一個非 Module 屬性,該屬性從get_attr節點使用。此方法可用於清理
nn.Module,而無需手動對每個未使用的子模組呼叫delete_submodule。注意
此 API 的向後相容性已得到保證。
- delete_submodule(target)[source]#
從
self中刪除給定的子模組。如果
target不是有效目標,則不會刪除該模組。- 引數
target (str) – 新子模組的完全限定名稱(請參閱
nn.Module.get_submodule中的示例,瞭解如何指定完全限定名稱)。- 返回
- 目標字串是否引用了
我們要刪除的子模組。返回值為
False意味著target不是對子模組的有效引用。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- print_readable(print_output=True, include_stride=False, include_device=False, colored=False, *, fast_sympy_print=False, expanded_def=False)[source]#
返回當前 GraphModule 及其子 GraphModule 生成的 Python 程式碼。
警告
此 API 是實驗性的,並且不向後相容。
- class torch.fx.Graph(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
Graph是 FX 中間表示使用的主要資料結構。它由一系列Node組成,每個Node代表一個呼叫點(或其他語法結構)。Node的列表共同構成一個有效的 Python 函式。例如,以下程式碼:
import torch import torch.fx class MyModule(torch.nn.Module): def __init__(self): super().__init__() self.param = torch.nn.Parameter(torch.rand(3, 4)) self.linear = torch.nn.Linear(4, 5) def forward(self, x): return torch.topk( torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 ) m = MyModule() gm = torch.fx.symbolic_trace(m)
將產生以下圖:
print(gm.graph)
graph(x): %linear_weight : [num_users=1] = self.linear.weight %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) return topk_1有關
Graph中操作語義的資訊,請參閱Node。注意
此 API 的向後相容性已得到保證。
- __init__(owning_module=None, tracer_cls=None, tracer_extras=None)[source]#
構造一個空的 Graph。
注意
此 API 的向後相容性已得到保證。
- call_function(the_function, args=None, kwargs=None, type_expr=None, name=None)[source]#
將
call_functionNode插入到Graph中。call_function節點表示對由the_function指定的 Python 可呼叫物件的呼叫。- 引數
the_function (Callable[..., Any]) – 要呼叫的函式。可以是任何 PyTorch 運算子、Python 函式或
builtins或operator名稱空間中的成員。args (Optional[Tuple[Argument, ...]]) – 要傳遞給被呼叫函式的 positional 引數。
kwargs (Optional[Dict[str, Argument]]) – 要傳遞給被呼叫函式的 keyword 引數。
type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
name (Optional[str]) – 節點的名稱。如果未指定,則設定為 None。
- 返回
新建立並插入的
call_function節點。- 返回型別
注意
此方法的插入點和型別表示式規則與
Graph.create_node()相同。注意
此 API 的向後相容性已得到保證。
- call_method(method_name, args=None, kwargs=None, type_expr=None)[source]#
將
call_methodNode插入到Graph中。call_method節點表示對args的第 0 個元素上的給定方法進行的呼叫。- 引數
method_name (str) – 要應用於 self 引數的方法名稱。例如,如果 args[0] 是表示
Tensor的Node,那麼要在該Tensor上呼叫relu(),請將relu傳遞給method_name。args (Optional[Tuple[Argument, ...]]) – 要傳遞給被呼叫方法的 positional 引數。請注意,這應該包含一個
self引數。kwargs (Optional[Dict[str, Argument]]) – 要傳遞給被呼叫方法的 keyword 引數。
type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
call_method節點。- 返回型別
注意
此方法的插入點和型別表示式規則與
Graph.create_node()相同。注意
此 API 的向後相容性已得到保證。
- call_module(module_name, args=None, kwargs=None, type_expr=None)[source]#
將一個
call_moduleNode插入到Graph中。call_module節點表示在Module層次結構中呼叫Module的 forward() 函式。- 引數
module_name (str) – 要呼叫的
Module在Module層次結構中的限定名稱。例如,如果跟蹤的Module包含一個名為foo的子模組,該子模組又包含一個名為bar的子模組,則應將限定名稱foo.bar作為module_name傳遞,以呼叫該模組。args (Optional[Tuple[Argument, ...]]) – 要傳遞給被呼叫方法的 positional 引數。請注意,這不應包含
self引數。kwargs (Optional[Dict[str, Argument]]) – 要傳遞給被呼叫方法的 keyword 引數。
type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
call_module節點。- 返回型別
注意
此方法的插入點和型別表示式規則與
Graph.create_node()相同。注意
此 API 的向後相容性已得到保證。
- create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)[source]#
建立一個
Node並將其插入到當前插入點的Graph中。請注意,當前插入點可以透過Graph.inserting_before()和Graph.inserting_after()來設定。- 引數
op (str) – 此 Node 的操作碼。可以是 ‘call_function’、‘call_method’、‘get_attr’、‘call_module’、‘placeholder’ 或 ‘output’。這些操作碼的語義在
Graph文件字串中有所描述。args (Optional[Tuple[Argument, ...]]) – 是此節點的引數元組。
kwargs (Optional[Dict[str, Argument]]) – 此 Node 的 kwargs。
name (Optional[str]) –
Node的可選字串名稱。這會影響在生成的 Python 程式碼中分配給該值的名稱。type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的節點。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- eliminate_dead_code(is_impure_node=None)[source]#
從圖中刪除所有死程式碼,基於每個節點的 use 數量以及節點是否具有任何 side effects。呼叫前必須對圖進行拓撲排序。
- 引數
is_impure_node (Optional[Callable[[Node], bool]]) – 一個返回
None (節點是否為 impure。如果是) –
to (則預設行為是) –
Node.is_impure. (使用) –
- 返回
圖是否因該過程而改變。
- 返回型別
示例
在消除死程式碼之前,下面 a = x + 1 中的 a 沒有 use,因此可以從圖中消除而不產生任何影響。
def forward(self, x): a = x + 1 return x + self.attr_1
消除死程式碼後,a = x + 1 已被刪除,其餘的 forward 保持不變。
def forward(self, x): return x + self.attr_1
警告
死程式碼消除有一些啟發式方法來避免刪除具有 side effects 的節點 (參見 Node.is_impure),但總的來說覆蓋率非常差,因此您應該假定呼叫此方法不是安全的,除非您知道您的 FX 圖完全由函式式操作組成,或者您提供了自己的自定義函式來檢測具有 side effects 的節點。
注意
此 API 的向後相容性已得到保證。
- erase_node(to_erase)[source]#
從
Graph中刪除一個Node。如果該節點在Graph中仍有 use,則會引發異常。- 引數
to_erase (Node) – 要從
Graph中刪除的Node。
注意
此 API 的向後相容性已得到保證。
- find_nodes(*, op, target=None, sort=True)[source]#
允許快速查詢節點。
- 引數
op (str) – 操作的名稱。
target (Optional[Target]) – 節點的 target。對於 call_function,target 是必需的。對於其他 op,target 是可選的。
sort (bool) – 是否按節點在圖上出現的順序返回節點。
- 返回
具有請求的操作和 target 的節點的可迭代物件。
警告
此 API 是實驗性的,並且不向後相容。
- get_attr(qualified_name, type_expr=None)[source]#
將一個
get_attr節點插入到 Graph 中。get_attrNode表示從Module層次結構中獲取一個屬性。- 引數
qualified_name (str) – 要檢索的屬性的完全限定名稱。例如,如果跟蹤的 Module 包含一個名為
foo的子模組,該子模組包含一個名為bar的子模組,該子模組有一個名為baz的屬性,則應將限定名稱foo.bar.baz作為qualified_name傳遞。type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
- 返回
新建立並插入的
get_attr節點。- 返回型別
注意
此方法與
Graph.create_node具有相同的插入點和型別表示式規則。注意
此 API 的向後相容性已得到保證。
- graph_copy(g, val_map, return_output_node=False)[source]#
將給定圖中的所有節點複製到
self中。- 引數
g (Graph) – 要從中複製 Nodes 的源圖。
val_map (Dict[Node, Node]) – 一個字典,將填充從
g中的節點到self中的節點的對映。請注意,可以傳入帶有值的val_map來覆蓋某些值的複製。
- 返回
self中與g中的輸出值等效的值,如果g有一個output節點。否則為None。- 返回型別
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向後相容性已得到保證。
- inserting_after(n=None)[source]#
- 設定 create_node 和伴隨方法將插入到圖中的點。
在 `with` 語句中使用時,這將臨時設定插入點,然後在 `with` 語句退出時將其恢復。
with g.inserting_after(n): ... # inserting after node n ... # insert point restored to what it was previously g.inserting_after(n) # set the insert point permanently
引數 (Args)
- n (Optional[Node]): 在其之前插入的節點。如果為 None,則將在圖的開頭之後插入。
整個圖的開頭。
- 返回
一個資源管理器,它將在 `__exit__` 時恢復插入點。
注意
此 API 的向後相容性已得到保證。
- inserting_before(n=None)[source]#
- 設定 create_node 和伴隨方法將插入到圖中的點。
在 `with` 語句中使用時,這將臨時設定插入點,然後在 `with` 語句退出時將其恢復。
with g.inserting_before(n): ... # inserting before node n ... # insert point restored to what it was previously g.inserting_before(n) # set the insert point permanently
引數 (Args)
- n (Optional[Node]): 在其之前插入的節點。如果為 None,則將在圖的開頭之前插入。
整個圖的開頭。
- 返回
一個資源管理器,它將在 `__exit__` 時恢復插入點。
注意
此 API 的向後相容性已得到保證。
- lint()[source]#
執行此 Graph 的各種檢查,以確保其結構良好。特別是: - 檢查 Nodes 是否具有正確的歸屬(屬於此圖) - 檢查 Nodes 是否按拓撲順序出現 - 如果此 Graph 具有擁有它的 GraphModule,則檢查 target 是否在該 GraphModule 中存在。
注意
此 API 的向後相容性已得到保證。
- node_copy(node, arg_transform=<function Graph.<lambda>>)[source]#
將一個節點從一個圖複製到另一個圖。`arg_transform` 需要將引數從 node 的圖轉換為 self 的圖。示例
# Copying all the nodes in `g` into `new_graph` g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])
- 引數
node (Node) – 要複製到
self中的節點。arg_transform (Callable[[Node], Argument]) – 一個函式,用於將 node 的 `args` 和 `kwargs` 中的 `Node` 引數轉換為 `self` 中等效的引數。在最簡單的情況下,這應該從一個將原始圖中的 Nodes 對映到 `self` 的表中檢索一個值。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- property nodes: list['Node']#
獲取構成此 Graph 的 Nodes 列表。
注意,這個 `Node` 列表表示是一個雙向連結串列。在迭代過程中進行修改(例如,刪除一個 Node,新增一個 Node)是安全的。
- 返回
Nodes 的雙向連結串列。注意,可以對此列表呼叫 `reversed` 來切換迭代順序。
- on_generate_code(make_transformer)[source]#
在生成 Python 程式碼時註冊一個 transformer 函式。
- 引數 (Args)
- make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc])
一個返回程式碼 transformer 的函式,用於註冊。此函式由 `on_generate_code` 呼叫以獲取程式碼 transformer。
此函式也作為輸入提供當前註冊的程式碼 transformer(如果未註冊,則為 None),以防不希望覆蓋它。這對於連結程式碼 transformer 很有用。
- 返回
一個上下文管理器,當在 `with` 語句中使用時,會自動恢復先前註冊的程式碼 transformer。
示例
gm: fx.GraphModule = ... # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual # debugging with the PDB library. def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( lambda body: insert_pdb( current_trans(body) if current_trans else body ) ) ) gm.recompile() gm(*inputs) # drops into pdb
此函式還可以用作上下文管理器,其優點是可以自動恢復先前註冊的程式碼 transformer。
# ... continue from previous example with gm.graph.on_generate_code(lambda _: insert_pdb): # do more stuff with `gm`... gm.recompile() gm(*inputs) # drops into pdb # now previous code transformer is restored (but `gm`'s code with pdb # remains - that means you can run `gm` with pdb here too, until you # run next `recompile()`).
警告
此 API 是實驗性的,並且不向後相容。
- output(result, type_expr=None)[source]#
將一個 `output`
Node插入到Graph中。`output` 節點表示 Python 程式碼中的 `return` 語句。`result` 是應返回的值。- 引數
result (Argument) – 要返回的值。
type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出將具有的 Python 型別。
注意
此方法與
Graph.create_node具有相同的插入點和型別表示式規則。注意
此 API 的向後相容性已得到保證。
- placeholder(name, type_expr=None, default_value)[source]#
將一個 `placeholder` 節點插入到 Graph 中。`placeholder` 表示函式輸入。
- 引數
name (str) – 輸入值的名稱。這對應於此 `Graph` 所代表的函式的 positional 引數的名稱。
type_expr (Optional[Any]) – 一個可選的型別註解,表示此節點輸出的 Python 型別。在某些情況下,這對於正確的程式碼生成是必需的(例如,當函式隨後在 TorchScript 編譯中使用時)。
default_value (Any) – 此函式引數應採用的預設值。注意:為了允許將 `None` 作為預設值,應將 `inspect.Signature.empty` 作為此引數傳遞,以指定該引數*沒有*預設值。
- 返回型別
注意
此方法與
Graph.create_node具有相同的插入點和型別表示式規則。注意
此 API 的向後相容性已得到保證。
- python_code(root_module, *, verbose=False, include_stride=False, include_device=False, colored=False, expanded_def=False)[source]#
將此 `Graph` 轉換為有效的 Python 程式碼。
- 引數
root_module (str) – 用於查詢限定名稱 target 的根模組的名稱。通常是 ‘self’。
- 返回
src:表示物件的 Python 原始碼 globals:`src` 中的全域性名稱字典 -> 它們引用的物件。
- 返回型別
一個 PythonCode 物件,包含兩個欄位。
注意
此 API 的向後相容性已得到保證。
- class torch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)[source]#
`Node` 是表示 `Graph` 中單個操作的資料結構。在大多數情況下,Nodes 表示對各種實體的呼叫,例如運算子、方法和 Modules(一些例外包括指定函式輸入和輸出的節點)。每個 `Node` 都有一個由其 `op` 屬性指定的函式。`op` 每個值的 `Node` 語義如下:
`placeholder` 表示函式輸入。`name` 屬性指定該值將採用的名稱。`target` 同樣是引數的名稱。`args` 持有:1) 空,或 2) 一個表示函式輸入的預設引數的單個引數。`kwargs` 是不關心的。
`get_attr` 從模組層次結構中檢索引數。`name` 同樣是分配給獲取結果的名稱。`target` 是引數在模組層次結構中的位置的完全限定名稱。`args` 和 `kwargs` 是不關心的。
`call_function` 將一個自由函式應用於某些值。`name` 同樣是分配給該值的名稱。`target` 是要應用的函式。`args` 和 `kwargs` 表示函式的引數,遵循 Python 呼叫約定。
`call_module` 將模組層次結構中的模組的 `forward()` 方法應用於給定引數。`name` 同上。`target` 是要呼叫的模組在模組層次結構中的完全限定名稱。`args` 和 `kwargs` 表示呼叫模組的引數,*不包括 self 引數*。
`call_method` 呼叫一個值上的方法。`name` 同上。`target` 是應用於 `self` 引數的方法的字串名稱。`args` 和 `kwargs` 表示呼叫模組的引數,*包括 self 引數*。
`output` 在其 `args[0]` 屬性中包含跟蹤函式的輸出。這對應於 Graph 列印輸出中的“return”語句。
注意
此 API 的向後相容性已得到保證。
- property all_input_nodes: list['Node']#
返回作為此 Node 輸入的所有 Nodes。這相當於迭代 `args` 和 `kwargs` 並僅收集是 Nodes 的值。
- 返回
出現在此 `Node` 的 `args` 和 `kwargs` 中的 `Nodes` 列表,按該順序。
- append(x)[source]#
在圖節點列表中在此節點之後插入 `x`。等同於 `self.next.prepend(x)`。
- 引數
x (Node) – 要放在此節點之後的節點。必須是同一圖的成員。
注意
此 API 的向後相容性已得到保證。
- property args: tuple[Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType], ...]#
此 `Node` 的引數元組。引數的解釋取決於節點的 op。有關更多資訊,請參閱 Node 文件字串。
允許對此屬性進行賦值。所有 use 和 users 的計數都會在賦值時自動更新。
- format_node(placeholder_names=None, maybe_return_typename=None, *, include_tensor_metadata=False)[source]#
返回 `self` 的描述性字串表示。
此方法可作為除錯實用程式使用,無需引數。
此函式還在 `Graph` 的 `__str__` 方法中用作內部助手。`placeholder_names` 和 `maybe_return_typename` 中的字串共同構成了此 Graph 外部 GraphModule 中自動生成的 `forward` 函式的簽名。`placeholder_names` 和 `maybe_return_typename` 不得在其他地方使用。
- 引數
placeholder_names (Optional[list[str]]) – 一個列表,將儲存表示生成的 `forward` 函式中 placeholder 的格式化字串。僅內部使用。
maybe_return_typename (Optional[list[str]]) – 一個單元素列表,將儲存表示生成的 `forward` 函式輸出的格式化字串。僅內部使用。
include_tensor_metadata (bool) – 是否包含張量元資料。
- 返回
- 如果 1) 我們正在使用 `format_node` 作為內部助手
在 `Graph` 的 `__str__` 方法中,並且 2) `self` 是一個 placeholder Node,則返回 `None`。否則,返回當前 Node 的描述性字串表示。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- insert_arg(idx, arg)[source]#
將一個 positional 引數插入到具有給定索引的引數列表中。
- 引數
idx (int) – 要插入之前的 `self.args` 中的元素的索引。
arg (Argument) – 要插入 `args` 中的新引數值。
注意
此 API 的向後相容性已得到保證。
- is_impure(impure_random=True)[source]#
返回此 op 是否為 impure,即如果其 op 是 placeholder 或 output,或者是一個 impure 的 call_function 或 call_module。
- 引數
impure_random (bool) – 是否將 rand op 視為 impure。
- 返回
op 是 impure 還是不是。
- 返回型別
警告
此 API 是實驗性的,並且不向後相容。
- property kwargs: dict[str, Union[tuple['Argument', ...], collections.abc.Sequence['Argument'], collections.abc.Mapping[str, 'Argument'], slice, range, torch.fx.node.Node, str, int, float, bool, complex, torch.dtype, torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, torch.SymInt, torch.SymBool, torch.SymFloat, NoneType]]#
此 `Node` 的關鍵字引數字典。引數的解釋取決於節點的 op。有關更多資訊,請參閱 Node 文件字串。
允許對此屬性進行賦值。所有 use 和 users 的計數都會在賦值時自動更新。
- normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)[source]#
返回 Python target 的標準化引數。這意味著 `args/kwargs` 將與 module/functional 的簽名匹配,並且如果 `normalize_to_only_use_kwargs` 為 true,則將僅按位置返回 kwargs。還將填充預設值。不支援僅 positional 引數或 varargs 引數。
支援模組呼叫。
可能需要 `arg_types` 和 `kwarg_types` 來區分過載。
- 引數
root (torch.nn.Module) – 用於解析模組 target 的模組。
arg_types (Optional[Tuple[Any]]) – arg 的 arg 型別元組。
kwarg_types (Optional[Dict[str, Any]]) – kwargs 的 arg 型別字典。
normalize_to_only_use_kwargs (bool) – 是否規範化為僅使用 kwargs。
- 返回
返回 NamedTuple ArgsKwargsPair,如果未成功則返回 `None`。
- 返回型別
Optional[ArgsKwargsPair]
警告
此 API 是實驗性的,並且不向後相容。
- prepend(x)[source]#
在圖節點列表中在此節點之前插入 x。示例
Before: p -> self bx -> x -> ax After: p -> x -> self bx -> ax
- 引數
x (Node) – 要放在此節點之前的節點。必須是同一圖的成員。
注意
此 API 的向後相容性已得到保證。
- replace_all_uses_with(replace_with, delete_user_cb=<function Node.<lambda>>, *, propagate_meta=False)[source]#
將 Graph 中 `self` 的所有 use 替換為 Node `replace_with`。
- 引數
replace_with (Node) – 用於替換 `self` 的所有 use 的節點。
delete_user_cb (Callable) – 用於確定是否應刪除 self 節點的給定 use 的回撥函式。
propagate_meta (bool) – 是否將原始節點的 .meta 欄位中的所有屬性複製到替換節點。為安全起見,僅當替換節點尚不具有 .meta 欄位時才有效。
- 返回
在此更改被執行的 Nodes 列表。
- 返回型別
list[‘Node’]
注意
此 API 的向後相容性已得到保證。
- replace_input_with(old_input, new_input)[source]#
遍歷 `self` 的輸入節點,並將 `old_input` 的所有例項替換為 `new_input`。
- 引數
old_input (Node) – 要被替換的舊輸入節點。
new_input (Node) – 用於替換 `old_input` 的新輸入節點。
注意
此 API 的向後相容性已得到保證。
- property stack_trace: Optional[str]#
返回跟蹤期間記錄的 Python 堆疊跟蹤(如果有)。當使用 fx.Tracer 進行跟蹤時,此屬性通常由 `Tracer.create_proxy` 填充。要在跟蹤期間記錄堆疊跟蹤以進行除錯,請在 `Tracer` 例項上設定 `record_stack_traces = True`。當使用 dynamo 進行跟蹤時,此屬性將由 `OutputGraph.create_proxy` 預設填充。
stack_trace 的字串末尾將是最內層的幀。
- class torch.fx.Tracer(autowrap_modules=(math,), autowrap_functions=())[source]#
`Tracer` 是實現 `torch.fx.symbolic_trace` 的符號跟蹤功能的類。呼叫 `symbolic_trace(m)` 等同於 `Tracer().trace(m)`。
Tracer 可以被子類化以覆蓋跟蹤過程的各種行為。可以覆蓋的不同行為在此類方法的文件字串中進行了描述。
注意
此 API 的向後相容性已得到保證。
- call_module(m, forward, args, kwargs)[source]#
指定此 `Tracer` 在遇到對 `nn.Module` 例項的呼叫時行為的方法。
預設情況下,行為是檢查被呼叫的模組是否為葉子模組(透過 `is_leaf_module`)。如果是,則在 `Graph` 中發出一個指向 `m` 的 `call_module` 節點。否則,正常呼叫 `Module`,跟蹤其 `forward` 函式中的操作。
可以覆蓋此方法來實現,例如,建立巢狀的跟蹤 GraphModules,或在跟蹤跨越 `Module` 邊界時執行任何其他所需行為。
- 引數
m (Module) – 正在發出呼叫的模組。
forward (Callable) – 要呼叫的 `Module` 的 forward() 方法。
args (Tuple) – 模組呼叫處的 args。
kwargs (Dict) – 模組呼叫處的 kwargs。
- 返回
來自 Module 呼叫的返回值。在發出 `call_module` 節點的情況下,這是一個 `Proxy` 值。否則,它是從 `Module` 呼叫返回的任何值。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- create_arg(a)[source]#
用於指定此 `Tracer` 在準備作為 `Graph` 中節點引數的值時的行為的方法。
預設行為包括:
迭代集合型別(例如,tuple、list、dict),並遞迴呼叫 `create_args` 來處理元素。
給定 Proxy 物件,返回對底層 IR `Node` 的引用。
給定非 Proxy Tensor 物件,為各種情況發出 IR。
對於 Parameter,發出指向該 Parameter 的 `get_attr` 節點。
對於非 Parameter Tensor,將該 Tensor 儲存在一個特殊的屬性中,該屬性指向該屬性。
可以覆蓋此方法以支援更多型別。
- 引數
a (Any) – 將作為 `Argument` 發出到 `Graph` 中的值。
- 返回
`a` 轉換為適當 `Argument` 的值。
- 返回型別
Argument
注意
此 API 的向後相容性已得到保證。
- create_args_for_root(root_fn, is_module, concrete_args=None)[source]#
建立與根模組(`root` Module)的簽名對應的 `placeholder` 節點。此方法內省 root 的簽名並相應地發出這些節點,還支援 `*args` 和 `**kwargs`。
警告
此 API 是實驗性的,並且不向後相容。
- create_node(kind, target, args, kwargs, name=None, type_expr=None)[source]#
根據目標、引數、關鍵字引數和名稱插入圖節點。
可以覆蓋此方法以執行額外的檢查、驗證或修改用於節點建立的值。例如,有人可能希望不允許記錄就地操作。
注意
此 API 的向後相容性已得到保證。
- 返回型別
- create_proxy(kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None)[source]#
根據給定的引數建立 Node,然後將 Node 包裝在 Proxy 物件中返回。
如果 kind = ‘placeholder’,那麼我們正在建立一個表示函式引數的 Node。如果需要編碼預設引數,我們使用 `args` 元組。對於 `placeholder` Nodes,`args` 否則為空。
注意
此 API 的向後相容性已得到保證。
- getattr(attr, attr_val, parameter_proxy_cache)[source]#
指定此 `Tracer` 在呼叫 `nn.Module` 例項上的 getattr 時行為的方法。
預設情況下,行為是返回屬性的代理值。它還將代理值儲存在 `parameter_proxy_cache` 中,以便將來的呼叫將重用該代理而不是建立新代理。
可以覆蓋此方法來實現,例如,在查詢引數時不返回代理。
- 引數
attr (str) – 查詢的屬性名稱。
attr_val (Any) – 屬性的值。
parameter_proxy_cache (Dict[str, Any]) – 屬性名到代理的快取。
- 返回
來自 getattr 呼叫的返回值。
警告
此 API 是實驗性的,並且不向後相容。
- is_leaf_module(m, module_qualified_name)[source]#
指定給定 `nn.Module` 是否為“葉子”模組的方法。
葉子模組是 IR 中出現的原子單元,由 `call_module` 呼叫引用。預設情況下,PyTorch 標準庫名稱空間(torch.nn)中的模組是葉子模組。所有其他模組都將被跟蹤並記錄其組成的 op,除非透過此引數另有指定。
- 引數
m (Module) – 查詢的模組。
module_qualified_name (str) – 到此模組根的路徑。例如,如果模組層次結構中的子模組 `foo` 包含子模組 `bar`,該子模組又包含子模組 `baz`,則該模組將在此處顯示為限定名稱 `foo.bar.baz`。
- 返回型別
注意
此 API 的向後相容性已得到保證。
- iter(obj)[source]#
- 在迭代代理物件時呼叫,例如
在控制流中使用時。通常我們不知道該怎麼做,因為我們不知道代理的值,但是自定義跟蹤器可以透過 create_node 將更多資訊附加到圖節點,並可以選擇返回一個迭代器。
注意
此 API 的向後相容性已得到保證。
- 返回型別
- keys(obj)[source]#
- 在代理物件呼叫 keys() 方法時呼叫。
這就是當對代理物件執行 ** 時發生的情況。這應該返回一個迭代器,如果 ** 在您的自定義跟蹤器中工作的話。
注意
此 API 的向後相容性已得到保證。
- 返回型別
- path_of_module(mod)[源]#
輔助方法,用於查詢
root的 Module 層次結構中mod的限定名稱。例如,如果root有一個名為foo的子模組,而foo又有一個名為bar的子模組,則將bar傳遞給此函式將返回字串 “foo.bar”。注意
此 API 的向後相容性已得到保證。
- to_bool(obj)[源]#
- 當代理物件被轉換為布林值時呼叫,例如
在控制流中使用時。通常我們不知道該怎麼做,因為我們不知道代理的值,但是自定義跟蹤器可以使用 create_node 向圖節點附加更多資訊,並可以選擇返回值。
注意
此 API 的向後相容性已得到保證。
- 返回型別
- class torch.fx.Proxy(node, tracer=None)[源]#
Proxy物件是Node的包裝器,它們在符號跟蹤期間在程式中流動,並將它們觸及的所有操作(torch函式呼叫、方法呼叫、運算子)記錄到不斷增長的 FX Graph 中。如果您正在進行圖變換,您可以將自己的
Proxy方法包裝在原始Node上,以便可以使用過載運算子將其他內容新增到Graph中。Proxy物件不能迭代。換句話說,如果Proxy在迴圈中使用或作為*args/**kwargs函式引數使用,符號跟蹤器將丟擲錯誤。通常有兩種方法可以解決此問題:1. 將無法跟蹤的邏輯提取到一個頂級函式中,並對其使用
fx.wrap。 2. 如果控制流是靜態的(即迴圈次數取決於某個超引數),則可以將程式碼保留在原始位置並重構為類似for i in range(self.some_hyperparameter): indexed_item = proxied_value[i]
有關 Proxy 內部機制的更詳細說明,請參閱 torch/fx/README.md 中的“Proxy”部分。
注意
此 API 的向後相容性已得到保證。
- class torch.fx.Interpreter(module, garbage_collect_values=True, graph=None)[源]#
Interpreter 按節點逐個執行 FX 圖。這種模式可用於許多用途,包括編寫程式碼轉換和分析傳遞。
可以覆蓋 Interpreter 類中的方法來定製執行行為。可覆蓋方法的呼叫層次圖
run() +-- run_node +-- placeholder() +-- get_attr() +-- call_function() +-- call_method() +-- call_module() +-- output()
示例
假設我們要將所有
torch.neg例項與torch.sigmoid相互替換(包括它們的Tensor方法等效項)。我們可以像這樣繼承 Interpreterclass NegSigmSwapInterpreter(Interpreter): def call_function( self, target: Target, args: Tuple, kwargs: Dict ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) torch.testing.assert_close(result, torch.neg(input).sigmoid())
- 引數
module (torch.nn.Module) – 要執行的模組
garbage_collect_values (bool) – 是否在模組執行期間刪除其最後一個使用的值。這可確保在執行過程中實現最佳記憶體使用。例如,可以透過檢視
Interpreter.env屬性來停用此項,以檢查執行中的所有中間值。graph (Optional[Graph]) – 如果已傳遞,直譯器將執行此圖而不是 module.graph,並使用提供的 module 引數來滿足任何狀態請求。
注意
此 API 的向後相容性已得到保證。
- boxed_run(args_list)[源]#
透過解釋執行 module 並返回結果。這使用了“boxed”呼叫約定,您需要傳遞一個引數列表,該列表將被直譯器清除。這確保輸入張量能被及時釋放。
注意
此 API 的向後相容性已得到保證。
- call_function(target, args, kwargs)[源]#
執行
call_function節點並返回結果。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回型別
- Return
Any: 函式呼叫返回的值
注意
此 API 的向後相容性已得到保證。
- call_method(target, args, kwargs)[源]#
執行
call_method節點並返回結果。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回型別
- Return
Any: 方法呼叫返回的值
注意
此 API 的向後相容性已得到保證。
- call_module(target, args, kwargs)[源]#
執行
call_module節點並返回結果。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回型別
- Return
Any: 模組呼叫返回的值
注意
此 API 的向後相容性已得到保證。
- fetch_args_kwargs_from_env(n)[源]#
從當前執行環境中獲取節點
n的args和kwargs的具體值。- 引數
n (Node) – 應從中獲取
args和kwargs的節點。- 返回
args和kwargs包含n的具體值。- 返回型別
Tuple[Tuple, Dict]
注意
此 API 的向後相容性已得到保證。
- fetch_attr(target)[源]#
從
self.module的Module層次結構中獲取一個屬性。- 引數
target (str) – 要獲取的屬性的完全限定名稱
- 返回
屬性的值。
- 返回型別
任何
注意
此 API 的向後相容性已得到保證。
- get_attr(target, args, kwargs)[源]#
執行
get_attr節點。將從self.module的Module層次結構中檢索屬性值。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回
檢索到的屬性值
- 返回型別
任何
注意
此 API 的向後相容性已得到保證。
- map_nodes_to_values(args, n)[源]#
遞迴地遍歷
args,並在當前執行環境中查詢每個Node的具體值。- 引數
args (Argument) – 用於查詢具體值的元資料結構
n (Node) –
args所屬的節點。這僅用於錯誤報告。
- 返回型別
Optional[Union[tuple[‘Argument’, …], Sequence[Argument], Mapping[str, Argument], slice, range, Node, str, int, float, bool, complex, dtype, Tensor, device, memory_format, layout, OpOverload, SymInt, SymBool, SymFloat]]
注意
此 API 的向後相容性已得到保證。
- output(target, args, kwargs)[源]#
執行
output節點。這實際上只是檢索output節點引用的值並返回它。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回
輸出節點引用的返回值
- 返回型別
任何
注意
此 API 的向後相容性已得到保證。
- placeholder(target, args, kwargs)[源]#
執行
placeholder節點。請注意,這是有狀態的:Interpreter維護一個關於傳遞給run的引數的內部迭代器,此方法返回該迭代器的 next()。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回
檢索到的引數值。
- 返回型別
任何
注意
此 API 的向後相容性已得到保證。
- class torch.fx.Transformer(module)[源]#
Transformer是一種特殊的直譯器,它生成新的Module。它公開了一個transform()方法,該方法返回轉換後的Module。Transformer不需要引數即可執行,而Interpreter需要。Transformer完全以符號方式工作。示例
假設我們要將所有
torch.neg例項與torch.sigmoid相互替換(包括它們的Tensor方法等效項)。我們可以像這樣繼承Transformerclass NegSigmSwapXformer(Transformer): def call_function( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(target, args, kwargs) def fn(x): return torch.sigmoid(x).neg() gm = torch.fx.symbolic_trace(fn) transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
- 引數
module (GraphModule) – 要轉換的
Module。
注意
此 API 的向後相容性已得到保證。
- get_attr(target, args, kwargs)[源]#
執行
get_attr節點。在Transformer中,這被重寫為將一個新的get_attr節點插入到輸出圖中。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回型別
注意
此 API 的向後相容性已得到保證。
- placeholder(target, args, kwargs)[源]#
執行
placeholder節點。在Transformer中,這被重寫為將一個新的placeholder插入到輸出圖中。- 引數
target (Target) – 此節點的呼叫目標。有關語義的詳細資訊,請參閱 Node
args (Tuple) – 此呼叫的位置引數元組
kwargs (Dict) – 此呼叫的關鍵字引數字典
- 返回型別
注意
此 API 的向後相容性已得到保證。
- torch.fx.replace_pattern(gm, pattern, replacement)[源]#
匹配 GraphModule (
gm) 的 Graph 中所有可能的非重疊的運算元集及其資料依賴性 (pattern),然後將每個匹配到的子圖替換為另一個子圖 (replacement)。- 引數
gm (GraphModule) – 包裝要操作的 Graph 的 GraphModule
pattern (Union[Callable, GraphModule]) – 在
gm中用於匹配替換的子圖replacement (Union[Callable, GraphModule]) – 用於替換
pattern的子圖
- 返回
一個
Match物件列表,表示原始圖中pattern匹配到的位置。如果沒有任何匹配,則列表為空。Match定義為class Match(NamedTuple): # Node from which the match was found anchor: Node # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node]
- 返回型別
List[Match]
示例
import torch from torch.fx import symbolic_trace, subgraph_rewriter class M(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x, w1, w2): m1 = torch.cat([w1, w2]).sum() m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) def pattern(w1, w2): return torch.cat([w1, w2]) def replacement(w1, w2): return torch.stack([w1, w2]) traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
上述程式碼將首先在
traced_module的forward方法中匹配pattern。模式匹配基於使用-定義關係,而不是節點名稱。例如,如果您在pattern中有p = torch.cat([a, b]),那麼您可以在原始forward函式中匹配m = torch.cat([a, b]),即使變數名不同(pvsm)。pattern可呼叫物件中的return語句僅基於其值進行匹配;它可能匹配也可能不匹配到較大圖中的return語句。換句話說,模式不必延伸到較大圖的末尾。當模式匹配成功時,它將從較大函式中移除並被
replacement替換。如果在較大函式中有多個pattern的匹配項,則每個非重疊的匹配項都將被替換。在匹配重疊的情況下,將替換重疊匹配項集中找到的第一個匹配項。(這裡的“第一個”定義為拓撲排序的使用-定義關係中的第一個節點。在大多數情況下,第一個節點是直接出現在self之後的引數,而最後一個節點是函式返回的任何內容。)需要注意的一點是,
pattern可呼叫物件的引數必須在可呼叫物件本身中使用,並且replacement可呼叫物件的引數必須與模式匹配。第一個規則解釋了為什麼在上面的程式碼塊中,forward函式具有引數x, w1, w2,而pattern函式只有引數w1, w2。pattern不使用x,因此不應將x指定為引數。作為第二個規則的示例,請考慮替換def pattern(x, y): return torch.neg(x) + torch.relu(y)
替換
def replacement(x, y): return torch.relu(x)
在這種情況下,
replacement需要與pattern相同數量的引數(x和y),即使引數y在replacement中未使用。呼叫
subgraph_rewriter.replace_pattern後,生成的 Python 程式碼如下所示:def forward(self, x, w1, w2): stack_1 = torch.stack([w1, w2]) sum_1 = stack_1.sum() stack_2 = torch.stack([w1, w2]) sum_2 = stack_2.sum() max_1 = torch.max(sum_1) add_1 = x + max_1 max_2 = torch.max(sum_2) add_2 = add_1 + max_2 return add_2
注意
此 API 的向後相容性已得到保證。