評價此頁

torch.hub#

創建於: 2025年6月13日 | 最後更新於: 2025年6月13日

Pytorch Hub 是一個預訓練模型庫,旨在促進研究的可復現性。

釋出模型#

Pytorch Hub 支援透過新增一個簡單的 hubconf.py 檔案,將預訓練模型(模型定義和預訓練權重)釋出到 GitHub 倉庫;

hubconf.py 可以有多個入口點。每個入口點都定義為一個 Python 函式(例如:你想釋出的預訓練模型)。

  def entrypoint_name(*args, **kwargs):
      # args & kwargs are optional, for models which take positional/keyword arguments.
      ...

如何實現一個入口點?#

以下是一個程式碼片段,指定了 resnet18 模型的一個入口點,如果我們展開 pytorch/vision/hubconf.py 中的實現。在大多數情況下,匯入正確的函式 hubconf.py 即可。這裡我們只是想使用展開的版本作為示例來展示它是如何工作的。你可以在 pytorch/vision repo 中找到完整的指令碼。

  dependencies = ['torch']
  from torchvision.models.resnet import resnet18 as _resnet18

  # resnet18 is the name of entrypoint
  def resnet18(pretrained=False, **kwargs):
      """ # This docstring shows up in hub.help()
      Resnet18 model
      pretrained (bool): kwargs, load pretrained weights into the model
      """
      # Call the model, load pretrained weights
      model = _resnet18(pretrained=pretrained, **kwargs)
      return model
  • dependencies 變數是 **載入** 模型所需的包名 **列表**。請注意,這可能與訓練模型所需的依賴項略有不同。

  • argskwargs 會被傳遞給實際的可呼叫函式。

  • 函式的文件字串用作幫助訊息。它解釋了模型的作用以及允許的位置/關鍵字引數。強烈建議在此處新增一些示例。

  • 入口點函式可以返回一個模型(nn.module),或用於簡化使用者工作流程的輔助工具,例如 tokenizers。

  • 以下劃線開頭的可呼叫函式被視為輔助函式,它們不會顯示在 torch.hub.list() 中。

  • 預訓練權重可以儲存在 GitHub 倉庫的本地,也可以由 torch.hub.load_state_dict_from_url() 載入。如果小於 2GB,建議將其附加到 專案 release 並使用 release 中的 URL。在上面的示例中,torchvision.models.resnet.resnet18 處理了 pretrained,或者你可以在入口點定義中放入以下邏輯。

  if pretrained:
      # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
      dirname = os.path.dirname(__file__)
      checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
      state_dict = torch.load(checkpoint)
      model.load_state_dict(state_dict)

      # For checkpoint saved elsewhere
      checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
      model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知#

  • 釋出的模型至少應該在一個分支/標籤下。不能是隨機提交。

從 Hub 載入模型#

Pytorch Hub 提供了方便的 API,可以透過 torch.hub.list() 探索 Hub 中所有可用的模型,透過 torch.hub.help() 檢視文件字串和示例,並使用 torch.hub.load() 載入預訓練模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[source]#

列出在 github 指定的倉庫中所有可用的入口點。

引數
  • github (str) – 格式為“repo_owner/repo_name[:ref]”的字串,其中 ref 是可選的(標籤或分支)。如果未指定 ref,則假定預設分支為 main(如果存在),否則為 master。示例:‘pytorch/vision:0.10’

  • force_reload (bool, optional) – 是否丟棄現有快取並強制進行新的下載。預設為 False

  • skip_validation (bool, optional) – 如果為 False,torchhub 將檢查 github 引數指定的 ref 是否正確屬於倉庫所有者。這將向 GitHub API 發出請求;你可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub token。預設為 False

  • trust_repo (bool, str or None) –

    "check", True, False or None。此引數在 v1.12 中引入,有助於確保使用者只執行來自他們信任的倉庫的程式碼。

    • 如果為 False,將提示使用者是否信任該倉庫。

    • 如果為 True,該倉庫將被新增到信任列表中,並在無需顯式確認的情況下載入。

    • 如果為 "check",將檢查該倉庫是否在快取的信任倉庫列表中。如果不在該列表中,行為將回退到 trust_repo=False 選項。

    • 如果為 None:將發出警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅為向後相容性存在,將在 v2.0 中刪除。

    預設為 None,並將在 v2.0 中更改為 "check"

  • verbose (bool, optional) – 如果為 False,則靜默關於命中本地快取的訊息。請注意,關於首次下載的訊息無法被靜默。預設為 True

返回

可用的入口點可呼叫函式

返回型別

列表

示例

>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[source]#

顯示入口點 model 的文件字串。

引數
  • github (str) – 格式為 <repo_owner/repo_name[:ref]> 的字串,其中 ref 是可選的(標籤或分支)。如果未指定 ref,則假定預設分支為 main(如果存在),否則為 master。示例:‘pytorch/vision:0.10’

  • model (str) – 倉庫 hubconf.py 中定義的入口點名稱。

  • force_reload (bool, optional) – 是否丟棄現有快取並強制進行新的下載。預設為 False

  • skip_validation (bool, optional) – 如果為 False,torchhub 將檢查 github 引數指定的 ref 是否正確屬於倉庫所有者。這將向 GitHub API 發出請求;你可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub token。預設為 False

  • trust_repo (bool, str or None) –

    "check", True, False or None。此引數在 v1.12 中引入,有助於確保使用者只執行來自他們信任的倉庫的程式碼。

    • 如果為 False,將提示使用者是否信任該倉庫。

    • 如果為 True,該倉庫將被新增到信任列表中,並在無需顯式確認的情況下載入。

    • 如果為 "check",將檢查該倉庫是否在快取的信任倉庫列表中。如果不在該列表中,行為將回退到 trust_repo=False 選項。

    • 如果為 None:將發出警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅為向後相容性存在,將在 v2.0 中刪除。

    預設為 None,並將在 v2.0 中更改為 "check"

示例

>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[source]#

從 GitHub 倉庫或本地目錄載入模型。

注意:載入模型是典型的用例,但也可以用於載入其他物件,如 tokenizers、損失函式等。

如果 source 為 ‘github’,則 repo_or_dir 預計格式為 repo_owner/repo_name[:ref],其中 ref 是可選的(標籤或分支)。

如果 source 為 ‘local’,則 repo_or_dir 預計是本地目錄的路徑。

引數
  • repo_or_dir (str) – 如果 source 為 ‘github’,則這應對應一個格式為 repo_owner/repo_name[:ref] 的 GitHub 倉庫,其中 ref 是可選的(標籤或分支),例如 ‘pytorch/vision:0.10’。如果未指定 ref,則假定預設分支為 main(如果存在),否則為 master。如果 source 為 ‘local’,則應是本地目錄的路徑。

  • model (str) – 倉庫/目錄的 hubconf.py 中定義的(入口點)可呼叫函式的名稱。

  • *args (optional) – 可呼叫函式 model 的相應引數。

  • source (str, optional) – ‘github’ 或 ‘local’。指定 repo_or_dir 的解釋方式。預設為 ‘github’。

  • trust_repo (bool, str or None) –

    "check", True, False or None。此引數在 v1.12 中引入,有助於確保使用者只執行來自他們信任的倉庫的程式碼。

    • 如果為 False,將提示使用者是否信任該倉庫。

    • 如果為 True,該倉庫將被新增到信任列表中,並在無需顯式確認的情況下載入。

    • 如果為 "check",將檢查該倉庫是否在快取的信任倉庫列表中。如果不在該列表中,行為將回退到 trust_repo=False 選項。

    • 如果為 None:將發出警告,邀請使用者將 trust_repo 設定為 FalseTrue"check"。這僅為向後相容性存在,將在 v2.0 中刪除。

    預設為 None,並將在 v2.0 中更改為 "check"

  • force_reload (bool, optional) – 是否無條件強制重新下載 GitHub 倉庫。如果 source = 'local',則沒有效果。預設為 False

  • verbose (bool, optional) – 如果為 False,則靜默關於命中本地快取的訊息。請注意,關於首次下載的訊息無法被靜默。如果 source = 'local',則沒有效果。預設為 True

  • skip_validation (bool, optional) – 如果為 False,torchhub 將檢查 github 引數指定的 ref 是否正確屬於倉庫所有者。這將向 GitHub API 發出請求;你可以透過設定 GITHUB_TOKEN 環境變數來指定非預設的 GitHub token。預設為 False

  • **kwargs (optional) – 可呼叫函式 model 的相應關鍵字引數。

返回

當使用給定的 *args**kwargs 呼叫 model 可呼叫函式時產生的輸出。

示例

>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[source]#

將給定 URL 的物件下載到本地路徑。

引數
  • url (str) – 要下載的物件的 URL

  • dst (str) – 物件將要儲存的完整路徑,例如 /tmp/temporary_file

  • hash_prefix (str, optional) – 如果不為 None,則下載檔案的 SHA256 應該以 hash_prefix 開頭。預設為 None

  • progress (bool, optional) – 是否向 stderr 顯示進度條。預設為 True

示例

>>> torch.hub.download_url_to_file(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
...     "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[source]#

從給定 URL 載入 Torch 序列化物件。

如果下載的檔案是 zip 檔案,它將被自動解壓縮。

如果物件已存在於 model_dir 中,則將其反序列化並返回。 model_dir 的預設值為 <hub_dir>/checkpoints,其中 hub_dir 是由 get_dir() 返回的目錄。

引數
  • url (str) – 要下載的物件的 URL

  • model_dir (str, optional) – 儲存物件的目錄

  • map_location (optional) – 指定如何重新對映儲存位置的函式或字典(參見 torch.load)

  • progress (bool, optional) – 是否向 stderr 顯示進度條。預設為 True

  • check_hash (bool, optional) – 如果為 True,URL 的檔名部分應遵循命名約定 filename-<sha256>.ext,其中 <sha256> 是檔案內容 SHA256 雜湊值的前八位或更多數字。雜湊值用於確保名稱的唯一性並驗證檔案內容。預設為 False

  • file_name (str, optional) – 下載檔案的名稱。如果未設定,將使用 URL 中的檔名。

  • weights_only (bool, optional) – 如果為 True,將只加載權重,不載入複雜的 pickled 物件。推薦用於不受信任的源。有關更多詳細資訊,請參閱 load()

返回型別

dict[str, Any]

示例

>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

執行已載入的模型:#

請注意,torch.hub.load() 中的 *args**kwargs 用於 **例項化** 模型。載入模型後,如何找出模型可以做什麼?建議的工作流程是:

  • dir(model) 檢視模型的所有可用方法。

  • help(model.foo) 檢查 model.foo 需要什麼引數才能執行。

為了幫助使用者在不來回查閱文件的情況下進行探索,我們強烈建議倉庫所有者使函式幫助訊息清晰簡潔。包含一個最小工作示例也很有幫助。

我的下載的模型儲存在哪裡?#

使用的位置順序如下:

  • 呼叫 hub.set_dir(<PATH_TO_HUB_DIR>)

  • $TORCH_HOME/hub,如果設定了環境變數 TORCH_HOME

  • $XDG_CACHE_HOME/torch/hub,如果設定了環境變數 XDG_CACHE_HOME

  • ~/.cache/torch/hub

torch.hub.get_dir()[source]#

獲取用於儲存下載的模型和權重的 Torch Hub 快取目錄。

如果未呼叫 set_dir(),則預設路徑為 $TORCH_HOME/hub,其中環境變數 $TORCH_HOME 預設為 $XDG_CACHE_HOME/torch$XDG_CACHE_HOME 遵循 Linux 檔案系統佈局的 X Design Group 規範,如果未設定環境變數,則預設為 ~/.cache

返回型別

str

torch.hub.set_dir(d)[source]#

可選地設定用於儲存下載的模型和權重的 Torch Hub 目錄。

引數

d (str) – 用於儲存下載模型和權重的本地資料夾路徑。

快取邏輯#

預設情況下,我們載入檔案後不清理。Hub 預設使用快取,如果它已存在於 get_dir() 返回的目錄中。

使用者可以透過呼叫 hub.load(..., force_reload=True) 來強制重新載入。這將刪除現有的 GitHub 資料夾和下載的權重,並重新初始化新的下載。這在更新發布到同一分支時非常有用,使用者可以跟上最新的釋出。

已知限制:#

Torch Hub 的工作方式是將包視為已安裝一樣匯入。在 Python 中匯入會引入一些副作用。例如,你會在 Python 快取 sys.modulessys.path_importer_cache 中看到新的條目,這是正常的 Python 行為。這也意味著,如果不同的倉庫具有相同的子包名稱(通常是 model 子包),你可能會在從不同倉庫匯入不同模型時遇到匯入錯誤。這些匯入錯誤的解決方法是:從 sys.modules 字典中刪除有問題的子包;更多詳細資訊可以在 此 GitHub issue 中找到。

一個值得在此提及的已知限制是:使用者**不能**在**同一個 Python 程序**中載入同一個倉庫的兩個不同分支。這就像在 Python 中安裝兩個同名包一樣,這是不好的。如果你確實嘗試這樣做,快取可能會介入並給你帶來意外。當然,在單獨的程序中載入它們是完全沒問題的。