
Dynamic Compilation Control with torch.compiler.set_stance#

Author: William Wen

torch.compiler.set_stance is a torch.compiler API that allows you to change the behavior of torch.compile across different calls to your model, without having to reapply torch.compile.

This recipe provides some examples of how to use torch.compiler.set_stance.

Prerequisites#

  • torch >= 2.6

Description#

torch.compiler.set_stance can be used as a decorator, context manager, or raw function to change the behavior of torch.compile across different calls to your model.

In the example below, the "force_eager" stance ignores all torch.compile directives.

import torch


@torch.compile
def foo(x):
    if torch.compiler.is_compiling():
        # torch.compile is active
        return x + 1
    else:
        # torch.compile is not active
        return x - 1


inp = torch.zeros(3)

print(foo(inp))  # compiled, prints 1
tensor([1., 1., 1.])

Sample decorator usage

@torch.compiler.set_stance("force_eager")
def bar(x):
    # force disable the compiler
    return foo(x)


print(bar(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

Sample context manager usage

with torch.compiler.set_stance("force_eager"):
    print(foo(inp))  # not compiled, prints -1
tensor([-1., -1., -1.])

Sample raw function usage

torch.compiler.set_stance("force_eager")
print(foo(inp))  # not compiled, prints -1
torch.compiler.set_stance("default")

print(foo(inp))  # compiled, prints 1
tensor([-1., -1., -1.])
tensor([1., 1., 1.])

The torch.compile stance can only be changed outside of any torch.compile region. Attempting to change it inside one will raise an error.

@torch.compile
def baz(x):
    # error!
    with torch.compiler.set_stance("force_eager"):
        return x + 1


try:
    baz(inp)
except Exception as e:
    print(e)


@torch.compiler.set_stance("force_eager")
def inner(x):
    return x + 1


@torch.compile
def outer(x):
    # error!
    return inner(x)


try:
    outer(inp)
except Exception as e:
    print(e)
Attempt to trace forbidden callable <function set_stance at 0x7f66ad509750>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 85, in baz
    with torch.compiler.set_stance("force_eager"):

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Attempt to trace forbidden callable <function inner at 0x7f66c86d3760>

from user code:
   File "/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py", line 103, in outer
    return inner(x)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Other stances include:
  • "default": The default stance, used for normal compilation.

  • "eager_on_recompile": Run code eagerly when a recompile is necessary. If there is cached compiled code that is valid for the input, it will still be used.

  • "fail_on_recompile": Raise an error when recompiling a function.

See the torch.compiler.set_stance documentation page for more stances and options. More stances/options may be added in the future.

Examples#

Preventing recompilation#

Some models do not expect any recompilations - for example, your inputs may always have the same shape. Since recompilations can be expensive, we may wish to raise an error whenever a recompile is attempted, so that we can detect and fix such cases. The "fail_on_recompile" stance can be used for this purpose.

@torch.compile
def my_big_model(x):
    return torch.relu(x)


# first compilation
my_big_model(torch.randn(3))

with torch.compiler.set_stance("fail_on_recompile"):
    my_big_model(torch.randn(3))  # no recompilation - OK
    try:
        my_big_model(torch.randn(4))  # recompilation - error
    except Exception as e:
        print(e)
Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '/var/lib/workspace/recipes_source/torch_compiler_set_stance_tutorial.py', function name: 'my_big_model', line number: 0

If raising an error is too disruptive, we can use "eager_on_recompile" instead, which makes torch.compile fall back to eager mode rather than erroring out. This can be useful when we do not expect recompilations to happen often, but when one is needed, we would rather pay the cost of running eagerly than the cost of a recompile.

@torch.compile
def my_huge_model(x):
    if torch.compiler.is_compiling():
        return x + 1
    else:
        return x - 1


# first compilation
print(my_huge_model(torch.zeros(3)))  # 1

with torch.compiler.set_stance("eager_on_recompile"):
    print(my_huge_model(torch.zeros(3)))  # 1
    print(my_huge_model(torch.zeros(4)))  # -1
    print(my_huge_model(torch.zeros(3)))  # 1
tensor([1., 1., 1.])
tensor([1., 1., 1.])
tensor([-1., -1., -1., -1.])
tensor([1., 1., 1.])

Measuring performance gains#

torch.compiler.set_stance can be used to compare eager performance vs. compiled performance, without having to define a separate eager model.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


@torch.compile
def my_gigantic_model(x, y):
    x = x @ y
    x = x @ y
    x = x @ y
    return x


inps = torch.randn(5, 5), torch.randn(5, 5)

with torch.compiler.set_stance("force_eager"):
    print("eager:", timed(lambda: my_gigantic_model(*inps))[1])

# warmups
for _ in range(3):
    my_gigantic_model(*inps)

print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
eager: 0.00026208001375198364
compiled: 0.00012691199779510498
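The timed helper above relies on CUDA events, which is the most accurate option when a GPU is involved. On a CPU-only setup, a rough equivalent (an assumption on my part, not part of the original recipe, and unable to account for asynchronous GPU work) could use time.perf_counter:

```python
import time


# CPU-only timing sketch: returns fn()'s result and the wall-clock
# duration in seconds. perf_counter is a monotonic, high-resolution
# clock, but unlike CUDA events it only measures host-side time.
def timed_cpu(fn):
    start = time.perf_counter()
    result = fn()
    end = time.perf_counter()
    return result, end - start


result, seconds = timed_cpu(lambda: sum(range(1_000_000)))
print(f"took {seconds:.6f} s")
```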

Crashing sooner#

Running one eager iteration with the "force_eager" stance before running a compiled iteration can help us catch errors unrelated to torch.compile before attempting a lengthy compilation.

@torch.compile
def my_humongous_model(x):
    return torch.sin(x, x)


try:
    with torch.compiler.set_stance("force_eager"):
        print(my_humongous_model(torch.randn(3)))
    # this call to the compiled model won't run
    print(my_humongous_model(torch.randn(3)))
except Exception as e:
    print(e)
sin() takes 1 positional argument but 2 were given

Conclusion#

In this recipe, we learned how to use the torch.compiler.set_stance API to modify a model's torch.compile behavior across different calls without reapplying torch.compile. The recipe demonstrated using torch.compiler.set_stance as a decorator, context manager, or raw function to control compilation stances such as force_eager, default, eager_on_recompile, and fail_on_recompile.

For more information, see: torch.compiler.set_stance API documentation

Total running time of the script: (0 minutes 10.660 seconds)