評價此頁

torch.export 流、常見挑戰及解決方案演示#

作者: Ankith Gunapal, Jordi Ramon, Marcos Carranza

torch.export 入門教程 中,我們學習瞭如何使用 torch.export。本教程在前一教程的基礎上進行了擴充套件,透過程式碼演示了匯出流行模型的流程,並解決了使用 torch.export 時可能遇到的常見挑戰。

在本教程中,您將學習如何針對以下用例匯出模型

選擇這四種模型是為了演示 torch.export 的獨特功能,以及在實現過程中遇到的一些實際考慮和問題。

先決條件#

  • PyTorch 2.4 或更高版本

  • torch.export 和 PyTorch Eager 推理有基本瞭解。

torch.export 的關鍵要求:無圖中斷#

torch.compile 透過使用 JIT 將 PyTorch 程式碼編譯為最佳化核心來加速 PyTorch 程式碼。它使用 TorchDynamo 最佳化給定模型,並建立一個最佳化的圖,然後使用 API 中指定的後端將其降低到硬體。當 TorchDynamo 遇到不支援的 Python 功能時,它會中斷計算圖,讓預設的 Python 直譯器處理不支援的程式碼,然後恢復捕獲圖。計算圖中的這種中斷稱為圖中斷

torch.exporttorch.compile 的一個關鍵區別是,torch.export 不支援圖中斷,這意味著您要匯出的整個模型或模型的一部分需要是一個單一的圖。這是因為處理圖中斷涉及使用預設的 Python 評估來解釋不支援的操作,這與 torch.export 的設計目的不相容。您可以在此 連結 中閱讀有關各種 PyTorch 框架之間差異的詳細資訊。

您可以使用以下命令識別程式中的圖中斷

TORCH_LOGS="graph_breaks" python <file_name>.py

您需要修改程式以消除圖中斷。一旦解決,您就可以匯出模型了。PyTorch 對流行的 HuggingFace 和 TIMM 模型執行 torch.compile夜間基準測試。其中大多數模型都沒有圖中斷。

此配方中的模型沒有圖中斷,但會因 torch.export 而失敗。

影片分類#

MViT 是一類基於 MultiScale Vision Transformers 的模型。該模型已使用 Kinetics-400 資料集 針對影片分類進行了訓練。該模型及其相關資料集可用於遊戲場景中的動作識別。

下面的程式碼透過使用 batch_size=2 進行追蹤來匯出 MViT,然後檢查匯出的程式是否可以使用 batch_size=4 執行。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

錯誤:靜態批次大小#

    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

預設情況下,匯出流程會假設所有輸入形狀都是靜態的,因此如果您使用與追蹤時使用的輸入形狀不同的輸入形狀執行程式,您將遇到錯誤。

解決方案#

為了解決此錯誤,我們指定輸入的第一維(batch_size)為動態,指定了預期的 batch_size 範圍。在下面顯示的更正示例中,我們指定預期的 batch_size 範圍可以是從 1 到 16。一個需要注意的細節是 min=2 並非錯誤,這一點在 0/1 特殊化問題 中得到了解釋。有關 torch.export 動態形狀的詳細描述,請參閱匯出教程。下面提供的程式碼演示瞭如何匯出具有動態批次大小的 mViT。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb


model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)

# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

自動語音識別#

自動語音識別 (ASR) 是利用機器學習將口語轉換為文字。 Whisper 是來自 OpenAI 的基於 Transformer 的編碼器-解碼器模型,它在 680,000 小時的 ASR 和語音翻譯標記資料上進行了訓練。下面的程式碼嘗試匯出用於 ASR 的 whisper-tiny 模型。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

錯誤:TorchDynamo 嚴格追蹤#

torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'

預設情況下,torch.export 使用 TorchDynamo(一個位元組碼分析引擎)追蹤您的程式碼,該引擎會符號化地分析您的程式碼並構建圖。這種分析提供了更強的安全性保證,但並非所有 Python 程式碼都支援。當我們使用預設的嚴格模式匯出 whisper-tiny 模型時,由於存在不支援的功能,它通常會在 Dynamo 中返回錯誤。要理解為什麼這會在 Dynamo 中導致錯誤,您可以參考此 GitHub issue

解決方案#

為了解決上述錯誤,torch.export 支援 non_strict 模式,在該模式下,程式使用 Python 直譯器進行追蹤,這與 PyTorch Eager 執行類似。唯一的區別是所有 Tensor 物件都將被 ProxyTensors 替換,後者會將它們的所有操作記錄到圖中。透過使用 strict=False,我們可以成功匯出程式。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

影像字幕生成#

影像字幕生成 是用文字定義影像內容的任務。在遊戲場景中,影像字幕生成可以透過動態生成場景中各種遊戲物件的文字描述來增強遊戲體驗,從而為玩家提供更多細節。 BLIP 是由 Salesforce Research 釋出的、流行的影像字幕生成模型。下面的程式碼嘗試使用 batch_size=1 匯出 BLIP。

import torch
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)

錯誤:無法修改具有凍結儲存的張量#

在匯出模型時,它可能會失敗,因為模型實現可能包含某些 Python 操作,而這些操作尚未得到 torch.export 的支援。其中一些失敗可能有解決方法。BLIP 就是一個例子,原始模型會失敗,但可以透過對程式碼進行少量更改來解決。 torch.exportExportDB 中列出了支援和不支援操作的常見情況,並展示瞭如何修改程式碼以使其相容匯出。

File "/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

解決方案#

克隆匯出失敗的張量

text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id

注意

此限制已在 PyTorch 2.7 夜間版本中放寬。這應該可以直接在 PyTorch 2.7 中正常工作。

可提示影像分割#

影像分割 是一種計算機視覺技術,它根據畫素的特徵將數字影像劃分為不同的組,即段。 Segment Anything Model (SAM) 引入了可提示影像分割,它可以根據指示所需物件的提示來預測物件掩碼。 SAM 2 是首個用於跨影像和影片分割物件的統一模型。 SAM2ImagePredictor 類提供了模型用於提示模型的簡單介面。該模型可以接受點提示和框提示以及前一迭代預測的掩碼作為輸入。由於 SAM2 對物件跟蹤提供了強大的零樣本效能,因此可用於跟蹤場景中的遊戲物件。

SAM2ImagePredictor 的 predict 方法中的張量操作發生在 _predict 方法中。因此,我們嘗試像這樣匯出。

ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)

錯誤:模型不是 torch.nn.Module 型別#

torch.export 需要模組是 torch.nn.Module 型別。但是,我們嘗試匯出的模組是一個類方法。因此會報錯。

Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

解決方案#

我們編寫了一個繼承自 torch.nn.Module 的輔助類,並在該類的 forward 方法中呼叫 _predict method。完整的程式碼可以在 這裡 找到。

class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

 model_to_export = ExportHelper()
 ep = torch.export.export(
      model_to_export,
      args=(unnorm_coords, labels, unnorm_box, mask_input,  multimask_output),
      kwargs={"return_logits": return_logits},
      strict=False,
      )

結論#

在本教程中,我們學習瞭如何透過正確的配置和簡單的程式碼修改來解決挑戰,從而使用 torch.export 匯出流行用例的模型。一旦您能夠匯出模型,就可以將 ExportedProgram 降低到伺服器端的 AOTInductor 或邊緣裝置端的 ExecuTorch。要了解有關 AOTInductor (AOTI) 的更多資訊,請參閱 AOTI 教程。要了解有關 ExecuTorch 的更多資訊,請參閱 ExecuTorch 教程