在哪裡應用 torch.compile?#
建立時間:2025 年 7 月 28 日 | 最後更新時間:2025 年 7 月 28 日
我們建議將 torch.compile 應用於不會導致過度問題的最高級別函式。通常情況下,它是
您的
train或eval步驟,包含最佳化器但不包含迴圈,您的頂級
nn.Module或一些子
nn.Module。
torch.compile 尤其不擅長處理 DDP 或 FSDP 等分散式包裝器模組,因此請考慮將 torch.compile 應用於傳遞給包裝器的內部模組。
# inference
model = ...
model.compile()
for _ in range(N_ITERS):
inp = ...
out = model(inp)
# training
model = ...
opt = torch.optim.Adam(model.parameters())
@torch.compile
def train(mod, data):
opt.zero_grad(True)
pred = mod(data[0])
loss = torch.nn.CrossEntropyLoss()(pred, data[1])
loss.backward()
opt.step()
for _ in range(N_ITERS):
inp = ...
train(model, inp)
# DistributedDataParallel
model = ...
model.compile()
model_ddp = DistributedDataParallel(model, ...)
for _ in range(N_ITERS):
inp = ...
out = model_ddp(inp)
compile(model) 與 model.compile()#
由於 torch.compile 與 nn.Module 例項的互動方式存在細微差別,因此如果您希望將 nn.Module 例項作為頂級函式進行編譯,我們建議使用 nn.Module 例項的 .compile() 方法。巢狀的模組呼叫將被正確跟蹤 - 在這種情況下無需呼叫 .compile()。
# DO NOT DO THIS
model = MyModel()
model = torch.compile(model)
model(inp)
# DO THIS
model = MyModel()
model.compile()
model(inp)
# this is also acceptable
@torch.compile
def fn(model, inp):
return model(inp)
model = MyModel()
fn(model, inp)