評價此頁

非嚴格追蹤程式設計模型#

建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日

摘要

  • 非嚴格追蹤 是一種追蹤 Python 程式碼的方式,它比 Dynamo 更寬鬆,但可能導致靜默的錯誤。

  • 非嚴格追蹤會執行一個 Python 函式,並利用 Python 和 PyTorch 的運算子過載能力,記錄執行過程中發生的 Tensor 操作,生成一個追蹤。

  • 如果一個函式符合某些約束條件,那麼它是非嚴格可追蹤的,即該函式是純函式,並且不直接操作 Tensor.data_ptr()。

  • 非嚴格追蹤可能會特化某些變數,並將它們視為常量,將變數的值“烘焙”到追蹤中。

torch.compile 的內部元件(make_fx, AOTDispatcher)使用非嚴格追蹤torch._dynamo.nonstrict_trace 也可以在 torch.compile 的程式碼中使用,以標記需要使用非嚴格追蹤進行追蹤的程式碼段。非嚴格追蹤會執行一個 Python 函式,並利用 Python 和 PyTorch 的運算子過載能力,記錄執行過程中發生的 Tensor 操作,生成一個追蹤。

make_fx 是非嚴格追蹤的主要入口點。對於以下函式,在輸入執行時只執行頂部的分支,因此它捕獲的圖只包含該分支。

from torch.fx.experimental.proxy_tensor import make_fx
def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None
        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None
        return div
        
'class f(torch.nn.Module):\n    def forward(self, x_1: "f32[3]"):\n        # No stacktrace found for following nodes\n        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None\n        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None\n        return div\n        '

非嚴格追蹤與 Dynamo(嚴格)追蹤的區別在於它是不安全的,也就是說,對於一個給定的函式,它捕獲的 Tensor 操作圖可能與原始函式具有不同的語義。對於一個 Python 函式,Dynamo 追蹤會捕獲 Tensor 操作圖和剩餘的位元組碼,它們的組合與 Python 函式具有相同的語義。

純函式#

非嚴格追蹤僅在純函式上是可靠的,因此只有純函式才應該進行非嚴格追蹤。

純函式是具有以下屬性的函式:

  • 確定性。 對於相同的輸入,純函式將始終返回相同的輸出。

  • 無副作用。 純函式沒有任何副作用,例如修改外部狀態或執行 I/O 操作。

  • 顯式的輸入/輸出。 所有輸入資料都必須透過函式引數傳遞,並且所有輸出都從函式中返回。

以下是一些非純函式的示例,在這些函式中,捕獲的圖與原始函式行為不同。

示例 1:無顯式輸入(例如,訪問全域性 Tensor)#

var = torch.tensor(1)
def function_with_global_access(y):
    return y + var
x = torch.tensor([0, 1, 2])
# _allow_non_fake_inputs=True is needed to capture the global variable
# for demonstration purposes.
gm = make_fx(
    function_with_global_access, tracing_mode="fake", _allow_non_fake_inputs=True
)(x)
# Non-strict Tracing captures the value of the global (1.)
print("1. call function", function_with_global_access(x))
print("1. call graph", gm(x))
# However, after changing the global, the captured graph
# produces a different result from the original function
var = torch.tensor(2)
print("2. call function", function_with_global_access(x))
print("2. call graph", gm(x))
# To capture a graph that can have a varying `var` tensor,
# it must be an explicit input:
def function_fixed(y, var):
    return y + var
var = torch.tensor(3)
gm = make_fx(function_fixed, tracing_mode="fake")(x, var)
print("3. call function", function_fixed(x, var))
print("3. call graph", gm(x, var))
var = torch.tensor(4)
print("4. call function", function_fixed(x, var))
print("4. call graph", gm(x, var))
1. call function tensor([1, 2, 3])
1. call graph tensor([1, 2, 3])
2. call function tensor([2, 3, 4])
2. call graph tensor([1, 2, 3])
3. call function tensor([3, 4, 5])
3. call graph tensor([3, 4, 5])
4. call function tensor([4, 5, 6])
4. call graph tensor([4, 5, 6])

有關原因,請參閱 特化和常量

示例 2:副作用(列印)#

def function_with_side_effect(y):
    print(y)
x = torch.tensor([0, 1, 2])
_ = function_with_side_effect(x)
tensor([0, 1, 2])

在 Python 中執行 f 會作為副作用列印一個 Tensor。

gm = make_fx(function_with_side_effect, tracing_mode="fake")(x)
FakeTensor(..., size=(3,), dtype=torch.int64)

在非嚴格追蹤期間,此打印發生在圖捕獲過程中。

_ = gm(x)

圖不儲存對 print 語句的呼叫,因此執行圖不會列印任何內容。

示例 3:副作用(列表突變)#

lst = []
def function_with_input_list_mutation(lst):
    val = lst.pop()
    return val
x = torch.tensor([0, 1, 2])
y = torch.tensor([0, 1, 2])
# Each time the function is executed, the list shrinks in size
lst = [x, y]
function_with_input_list_mutation(lst)
print("len(lst) after one call", len(lst))
function_with_input_list_mutation(lst)
print("len(lst) after two calls", len(lst))
# With Non-strict Tracing, the length of the list shrinks during
# the graph capture but not in invocations of the graph.
lst = [x, y]
gm = make_fx(function_with_input_list_mutation, tracing_mode="fake")(lst)
print("len(lst) after graph capture", len(lst))
gm(lst)
print("len(lst) after one call to graph", len(lst))
gm(lst)
print("len(lst) after two calls to graph", len(lst))
len(lst) after one call 1
len(lst) after two calls 0
len(lst) after graph capture 2
len(lst) after one call to graph 2
len(lst) after two calls to graph 2

無直接 data_ptr 操作#

直接操作 Tensor.data_ptr 是不可非嚴格追蹤的。其背後的直覺是,PyTorch 無法知道您是如何操作 data_ptr 的。

import ctypes
# Create a tensor with a single element
tensor = torch.tensor([42], dtype=torch.int32)  # Using int32 for simplicity
def function_with_data_ptr(tensor):
    # Get the data pointer
    ptr = tensor.data_ptr()
    # Cast the pointer to a ctypes pointer
    ctypes_ptr = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_int32))
    # Increment the value at the pointer
    ctypes_ptr.contents.value += 1
    return tensor
try:
    make_fx(function_with_data_ptr, tracing_mode="fake")(tensor)
except Exception as e:
    print(e)
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.com.tw/tutorials/advanced/custom_ops_landing_page.html

特化和常量#

非嚴格追蹤捕獲的圖可能在某些值上進行了特化。這意味著捕獲的圖僅對這些值有效。我們說該圖將這些值視為常量

在非嚴格追蹤期間,所有非 Tensor 變數都被視為常量。

def f(x, y):
    return x + y
x = torch.tensor([0, 1, 2])
y = 3.14
gm = make_fx(f, tracing_mode="fake")(x, y)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "i64[3]", y_1):
        # No stacktrace found for following nodes
        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None
        return add
        
'class f(torch.nn.Module):\n    def forward(self, x_1: "i64[3]", y_1):\n        # No stacktrace found for following nodes\n        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None\n        return add\n        '

3.14 是圖中的一個常量。

非嚴格追蹤還會對輸入 Tensor 的屬性進行特化。

def f(x):
    if x.shape[0] > 2:
        return x ** 2 / 6
    else:
        return x * 3
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None
        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None
        return div
        
'class f(torch.nn.Module):\n    def forward(self, x_1: "f32[3]"):\n        # No stacktrace found for following nodes\n        pow_1: "f32[3]" = torch.ops.aten.pow.Tensor_Scalar(x_1, 2);  x_1 = None\n        div: "f32[3]" = torch.ops.aten.div.Tensor(pow_1, 6);  pow_1 = None\n        return div\n        '

它還會對未直接傳遞給函式的任何變數進行特化。

var = torch.tensor(1)
def f(x):
    return x + y
x = torch.randn(3)
gm = make_fx(f, tracing_mode="fake")(x)
gm.print_readable()
class f(torch.nn.Module):
    def forward(self, x_1: "f32[3]"):
        # No stacktrace found for following nodes
        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None
        return add
        
'class f(torch.nn.Module):\n    def forward(self, x_1: "f32[3]"):\n        # No stacktrace found for following nodes\n        add: "f32[3]" = torch.ops.aten.add.Tensor(x_1, 3.14);  x_1 = None\n        return add\n        '