評價此頁
fullgraph=False">

在哪裡應用 torch.compile?#

建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日

我們建議將 torch.compile 應用於不會導致過度問題的最高級別函式。通常情況下,它是

  • 您的 traineval 步驟,包含最佳化器但不包含迴圈,

  • 您的頂級 nn.Module

  • 或一些子 nn.Module

torch.compile 尤其不擅長處理 DDP 或 FSDP 等分散式包裝器模組,因此請考慮將 torch.compile 應用於傳遞給包裝器的內部模組。

# inference
model = ...
model.compile()

for _ in range(N_ITERS):
    inp = ...
    out = model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())

@torch.compile
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()

for _ in range(N_ITERS):
    inp = ...
    train(model, inp)
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)

for _ in range(N_ITERS):
    inp = ...
    out = model_ddp(inp)

compile(model)model.compile()#

由於 torch.compilenn.Module 例項的互動方式存在細微差別,因此如果您希望將 nn.Module 例項作為頂級函式進行編譯,我們建議使用 nn.Module 例項的 .compile() 方法。巢狀的模組呼叫將被正確跟蹤 - 在這種情況下無需呼叫 .compile()

# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)

# DO THIS
model = MyModel()
model.compile()
model(inp)

# this is also acceptable
@torch.compile
def fn(model, inp):
    return model(inp)
model = MyModel()
fn(model, inp)