評價此頁

torch.load#

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[source]#

從檔案中載入使用 torch.save() 儲存的物件。

torch.load() 使用 Python 的反序列化功能,但會特殊處理張量底層的儲存。它們首先在 CPU 上反序列化,然後移動到儲存時的裝置。如果失敗(例如,因為執行時系統沒有某些裝置),則會引發異常。但是,可以使用 map_location 引數動態地將儲存重新對映到備用裝置集。

如果 map_location 是一個可呼叫物件,它將為每個序列化的儲存呼叫一次,有兩個引數:儲存和位置。儲存引數將是儲存的初始反序列化,駐留在 CPU 上。每個序列化的儲存都有一個關聯的位置標籤,該標籤標識了它儲存的裝置,並且該標籤是傳遞給 map_location 的第二個引數。內建的位置標籤是 CPU 張量的 'cpu',CUDA 張量的 'cuda:device_id'(例如 'cuda:2')。map_location 應返回 None 或一個儲存。如果 map_location 返回一個儲存,它將被用作最終反序列化的物件,並已移動到正確的裝置。否則,torch.load() 將回退到預設行為,就好像沒有指定 map_location 一樣。

如果 map_location 是一個 torch.device 物件或包含裝置標籤的字串,它將指示所有張量應載入到的位置。

否則,如果 map_location 是一個字典,它將用於將檔案中出現的位置標籤(鍵)重新對映到指定儲存位置(值)的標籤。

使用者擴充套件可以使用 torch.serialization.register_package() 註冊自己的位置標籤以及標記和反序列化方法。

有關操作 checkpoint 的更高階工具,請參閱 佈局控制

引數
  • f (Union[str, PathLike[str], IO[bytes]]) – 一個類似檔案的物件(必須實現 read()readline()tell()seek()),或包含檔名的字串或 os.PathLike 物件

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, dict[str, str]]]) – 一個函式、torch.device、字串或字典,用於指定如何重新對映儲存位置

  • pickle_module (Optional[Any]) – 用於反序列化元資料和物件的模組(必須與序列化檔案時使用的 pickle_module 匹配)

  • weights_only (Optional[bool]) – 指示反序列化器是否應僅限於載入張量、基本型別、字典以及透過 torch.serialization.add_safe_globals() 新增的任何型別。有關更多詳細資訊,請參閱 torch.load with weights_only=True

  • mmap (Optional[bool]) – 指示是否應對映檔案而不是將所有儲存載入到記憶體中。通常,檔案中的張量儲存將首先從磁碟移動到 CPU 記憶體,然後移動到儲存時標記的裝置,或者 map_location 指定的裝置。如果最終位置是 CPU,則此第二步為空操作。當設定 mmap 標誌時,而不是在第一步將張量儲存從磁碟複製到 CPU 記憶體,f 將被對映,這意味著張量儲存將在訪問其資料時惰性載入。

  • pickle_load_args (Any) – (僅限 Python 3) 傳遞給 pickle_module.load()pickle_module.Unpickler() 的可選關鍵字引數,例如 errors=...

返回型別

任何

警告

torch.load(),除非 weights_only 引數設定為 True,否則會隱式使用 pickle 模組,該模組已知不安全。有可能構造惡意的 pickle 資料,這些資料將在反序列化過程中執行任意程式碼。切勿在不安全模式下載入可能來自不受信任來源或可能被篡改的資料。**僅載入您信任的資料**。

注意

當你對包含 GPU 張量的檔案呼叫 torch.load() 時,預設情況下,這些張量將被載入到 GPU。你可以呼叫 torch.load(.., map_location='cpu') 然後呼叫 load_state_dict() 來避免在載入模型 checkpoint 時出現 GPU 記憶體激增。

注意

預設情況下,我們將位元組字串解碼為 utf-8。這是為了避免在 Python 3 中載入 Python 2 儲存的檔案時出現常見的錯誤情況 UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果此預設值不正確,你可以使用額外的 encoding 關鍵字引數來指定如何載入這些物件,例如 encoding='latin1' 使用 latin1 編碼將它們解碼為字串,而 encoding='bytes' 將它們保留為位元組陣列,稍後可以用 byte_array.decode(...) 解碼。

示例

>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load(
...     "tensors.pt",
...     map_location=torch.device("cpu"),
...     weights_only=True,
... )
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage,
...     weights_only=True,
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load(
...     "tensors.pt",
...     map_location={"cuda:1": "cuda:0"},
...     weights_only=True,
... )
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)