在 ATen IR 上編寫圖變換#
建立日期:2025 年 6 月 11 日 | 最後更新日期:2025 年 6 月 11 日
Passes#
由於 ATen IR 位於 FX Graph/GraphModule 層,因此為 FX Graph 編寫的任何變換都可以輕鬆應用於 ATen IR。如果您熟悉編寫 FX 圖變換,那麼這與您已經瞭解的相同。
編寫變換的最直接方法是遍歷給定的圖並直接操作圖中的節點。
例如,假設我們要將 torch.ops.aten.add.Tensor() 呼叫替換為 torch.ops.aten.mul.Tensor() 呼叫
import torch
def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
node.target = torch.ops.aten.mul.Tensor
我們還可以透過 FX 實用函式刪除和新增新節點,這些函式可以在 Graph 文件中找到。例如,如果我們想在 add 呼叫之後插入一個 torch.ops.aten.relu.default()
import torch
def insert_relu_after_add(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
# Specifies the insertion point. Any nodes added to the graph within
# this scope will be inserted after `node`
with gm.graph.inserting_after(node):
# Insert a new `call_function` node with op `torch.ops.aten.relu.default`
new_relu_node = gm.graph.call_function(torch.ops.aten.relu.default, args=(node,))
# Replace all the places that use `node` to now use the `new_relu_node`
node.replace_all_uses_with(new_relu_node)
總的來說,變換可以大致分為幾個軸
軸 A:1. 建立一對多對映(例如,分解) 2. 建立多對一對映(例如,融合)
軸 B:1. 正向迭代(例如,形狀傳播) 2. 反向迭代(例如,死程式碼消除)
軸 C:1. 依賴於區域性節點資訊(例如,out-variant 轉換) 2. 依賴於全域性圖資訊(例如,記憶體規劃)
我們對這些用例頻率的預測是:1. A.1,B.1,C.1 2. A.2 3. B.2,C.2
雖然我們可以透過直接操作圖來實現所有圖變換,但我們也提供了一些輔助工具,以便在處理 1 級和 2 級用例時更加方便。
Transformer#
對於 1 級用例(建立一對多對映、進行正向迭代和檢視區域性節點資訊),我們可以利用 Transformer 類來執行每個節點並重新建立一個圖,但會應用指定的變換。
一對一 Pass#
對於一對一對映的示例,如果我們想用另一個 op B 替換 op A,我們可以執行 GraphModule,每次看到 op A 時,返回 op B。
一個例子是
class ReplaceAddWithMul(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
return super().call_function(torch.ops.aten.mul.Tensor, args, kwargs)
transformed_graph_module = ReplaceAddWithMul(graph_module).transform()
對 super().call_function(target, args, kwargs, meta) 的呼叫會建立一個 call_function FX 節點,並返回使用給定引數執行操作的結果。
一對多 Pass#
如果我們想進行一對多對映,例如用兩個其他 op B 和 C 替換 op A,那麼我們將呼叫兩次 super().call_function 來建立兩個 FX 節點,一個 op B,另一個 op C,並返回執行 op C 的結果。
例如
class ReplaceAddWithMulSub(torch.fx.Transformer):
"""
Original:
def f(x, y):
return x + y
After pass:
def f(x, y):
z = x * y
return z - y
"""
def call_function(self, target, args, kwargs):
if target != torch.ops.aten.add.Tensor:
return super().call_function(target, args, kwargs)
x, y = args
mul_res = super().call_function(torch.ops.aten.mul.Tensor, args, {})
return super().call_function(torch.ops.aten.sub.Tensor, (mul_res, y), {})
transformed_graph_module = ReplaceAddWithMulSub(graph_module).transform()
一對零 Pass#
如果我們想移除一個 op,我們可以直接返回傳遞給函式的該值
class RemoveDetachPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
if target not in (
torch.ops.aten.detach.default,
torch.ops.aten.detach_copy.default,
):
return super().call_function(target, args, kwargs, meta)
assert len(args) == 1
return args[0]
transformed_graph_module = RemoveDetachPass(graph_module).transform()
利用區域性資訊#
利用區域性節點資訊的例子是,如果我們想將圖中的所有標量轉換為張量,我們可以執行給定的 fx.GraphModule,並且對於包含標量的每個引數,我們將其轉換為張量。它可能看起來像
def args_map(target, fn, args, kwargs):
assert isinstance(args, tuple)
assert isinstance(kwargs, dict)
args = list(args)
kwargs = kwargs.copy()
# Update the argument based on the function passed
def update(key, args, schema):
args[key] = fn(args[key], schema)
# Update each argument in the schema
for i, schema in enumerate(target._schema.arguments):
if schema.name in kwargs:
update(schema.name, kwargs, schema)
elif not schema.kwarg_only and i < len(args):
update(i, args, schema)
return tuple(args), kwargs
class ScalarToTensorPass(torch.fx.Transformer):
def call_function(self, target, args, kwargs):
breakpoint()
def try_coerce(value, arg):
return (
torch.tensor(value)
if isinstance(value, (float, int, bool))
and type(arg.type) == torch.TensorType
else value
)
args, kwargs = args_map(target, try_coerce, args, kwargs)
return super().call_function(target, args, kwargs)
transformed_graph_module = ScalarToTensorPass(graph_module).transform()
子圖重寫器#
要建立多對一對映,我們可以利用 FX 的 subgraph rewriter。給定一個 pattern,它會建立一個匹配該模式的操作子圖,然後將每個匹配的子圖替換為 replacement。
注意
This is an inplace operation.
pattern 和 replacement 輸入必須是可呼叫函式或包含與圖中使用(ATen ops)的相同操作的 GraphModules,這樣子圖重寫器才能在圖中找到正確的模式。模式/替換可呼叫物件的輸入在匹配時將被視為萬用字元。
一個例子
from torch.fx import subgraph_rewriter
def replace_patterns(graph_module):
def pattern(x, y):
x = torch.ops.aten.add.Tensor(x, y)
x = torch.ops.aten.mul.Tensor(x, y)
return x
def replacement(x, y):
return torch.ops.aten.sub.Tensor(x, y)
replaced_patterns = subgraph_rewriter.replace_pattern_with_filters(
traced_module, pattern, replacement
)
子圖重寫器返回一個 ReplacedPatterns 列表
@dataclass
class ReplacedPatterns:
# Node from which the match was found
anchor: Node
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node]
# List of nodes that were added into the graph
replacements: List[Node]
注意
The nodes created by the subgraph rewriter will not have the metadata that
is populated in the matched nodes, but you can use
`ReplacedPatterns.nodes_map` to find the nodes in the original graph that
were matched, and `ReplacedPatterns.replacements` to find the nodes that
were replaced in the transformed graph.
Pass Manager#
PassManager 是一個用於在給定圖模組上執行多個 pass 的類。在初始化 PassManager 例項時,我們會傳入一個我們想要執行的 pass 列表,並設定幾個標誌。要在圖模組上執行 pass 集合,我們可以直接將圖模組傳遞給 PassManager 例項。
一個例子
from torch.fx.passes.infra.pass_manager import PassManager
pm = PassManager(
passes=[replace_add_with_div, replace_div_with_mul],
run_checks_after_each_pass=True,
suppress_check_failures=False,
)
graph_module_out = pm(graph_module)
要新增一組常見的檢查,在每個 pass 執行後進行,我們可以呼叫函式 set_checks(check: Callable),該函式接受一個可呼叫函式作為輸入。如果設定了 run_checks_after_each_pass 標誌,則在圖模組上執行每個 pass 後將呼叫 check。
一個例子
pm = PassManager(passes=[replace_add_with_div, replace_div_with_mul])
def check_div_target(graph_module):
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target != torch.div:
raise ValueError("Target should be div!")
pm.add_checks(check_div_target)
pm(graph_module) # raises ValueError after replace_div_with_mul pass
Partitioner#
我們可以使用幾種常見的基於 FX 圖的分割槽器來對圖進行分割槽。
子圖匹配器#
要查詢圖中與特定模式匹配的子圖,我們可以利用 FX 的 SubgraphMatcher。
類屬性
pattern (Graph):目標匹配模式。圖中的佔位符節點在匹配時將被視為萬用字元。match_output (bool):如果為 True,則模式圖中的輸出節點將被視為目標模式的一部分。如果為 False,則在匹配期間忽略輸出節點。match_placeholder (bool):如果為 True,則模式圖中的佔位符節點將被視為目標模式的一部分。如果為 False,則佔位符節點將用作萬用字元。remove_overlapping_matches (bool):如果為 True,在重疊匹配的情況下,將只返回第一個匹配項。ignore_literals (bool):如果為 True,將不檢查字面量是否相等,而是將它們視為萬用字元。
一個例子
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
class LargeModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight = torch.nn.Parameter(torch.ones(3, 3))
self._bias = torch.nn.Parameter(torch.ones(3, 3))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias, x, self._weight)
large_model_graph = torch.export(LargeModel(), inputs).graph
class PatternModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x):
return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)
pattern_graph = torch.export(PatternModel(), inputs).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
match_result = subgraph_matcher.match(large_model_graph)
match 函式返回一個 InternalMatch 列表
@dataclass
class InternalMatch():
# Nodes from which the match was found
anchors: List[Node]
# Maps nodes in the pattern subgraph to nodes in the larger graph
nodes_map: Dict[Node, Node] = field(default_factory=dict)
# Nodes in target graph that are matched placeholder in pattern
placeholder_nodes: List[Node] = field(default_factory=list)
# Nodes in matched subgraph returned by output
returning_nodes: List[Node] = field(default_factory=list)
基於能力的 Partitioner#
要查詢支援特定不變性的最大節點子圖,我們可以利用 FX 的 CapabilityBasedPartitioner。
類屬性
graph_module (torch.fx.GraphModule):我們正在對其進行分割槽的圖模組。operator_support (OperatorSupportBase):用於確定圖中的節點是否受分割槽支援的物件。allows_single_node_partition (bool):如果為 True,則允許形成單節點分割槽。non_compute_ops (Optional[Sequence[str]]):一組被視為“非計算”的操作(例如torch.ops.aten.view和_operator.getitem),以便分割槽器不會建立僅包含這些非計算操作的圖。allowed_single_node_partition_ops (Optional[Sequence[str]]):允許在單節點分割槽中使用的操作集。
OperatorSupportBase 類由分割槽器用來確定圖中的特定節點是否屬於該分割槽。這是透過覆蓋 is_node_supported 函式來完成的。您可以透過使用 chain(如果任何 OperatorSupportBase 返回 False,則返回 False)和 any_chain(如果任何 OperatorSupportBase 返回 True,則返回 True)來連結多個 OperatorSupportBase。
一個例子
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
class AddMulOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
torch.ops.aten.add.Tensor, torch.ops.aten.mul.Tensor,
]
capability_partitioner = CapabilityBasedPartitioner(
graph_module,
op_support,
)
# Returns a list of partitions (list of nodes that belong in each partition)
partition_list = capability_partitioner.propose_partitions()
# Fuses the partitions into graph modules and inserts `call_module` nodes in the graph
fused_graph_module = capability_partitioner.fuse_partitions(partition_list)