評價此頁

序列化語義#

建立日期:2017年2月26日 | 最後更新日期:2025年6月23日

本文件描述瞭如何在Python中儲存和載入PyTorch張量和模組狀態,以及如何序列化Python模組以便在C++中載入。

儲存和載入張量#

torch.save()torch.load() 可以方便地儲存和載入張量。

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照慣例,PyTorch 檔案通常使用 '.pt' 或 '.pth' 副檔名。

torch.save()torch.load() 預設使用 Python 的 pickle,因此您也可以將多個張量作為 Python 物件(如元組、列表和字典)的一部分進行儲存。

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

包含 PyTorch 張量的自定義資料結構也可以儲存,前提是該資料結構是可 pickle 的。

儲存和載入張量會保留檢視#

儲存張量會保留它們的檢視關係。

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在後臺,這些張量共享相同的“儲存”。有關檢視和儲存的更多資訊,請參閱 張量檢視

當 PyTorch 儲存張量時,它會單獨儲存它們的儲存物件和張量元資料。這是可能在未來更改的實現細節,但它通常可以節省空間,並使 PyTorch 能夠輕鬆地重建載入張量之間的檢視關係。例如,在上面的程式碼片段中,只有一個儲存被寫入 'tensors.pt'。

然而,在某些情況下,儲存當前的儲存物件可能是不必要的,並會導致生成過大的檔案。在下面的程式碼片段中,為與 large 共享的儲存(包含 999 個元素)建立的檔案比 small 張量(僅包含 5 個元素)要大得多。

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

與將 small 張量中的五個元素儲存到 'small.pt' 不同,這裡儲存和載入的是 small 張量所共享的、包含 999 個元素的儲存。

當儲存的張量中的元素數量少於其儲存物件中的元素數量時,可以透過先克隆張量來減小儲存檔案的大小。克隆張量會建立一個新的張量,它具有一個新的儲存物件,僅包含張量中的值。

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

但是,由於克隆的張量彼此獨立,因此它們不具有原始張量所具有的任何檢視關係。如果儲存的張量小於其儲存物件,並且檔案大小和檢視關係都很重要,那麼在儲存之前,必須小心地構造新的張量,以最小化其儲存物件的大小,同時仍保持所需的檢視關係。

儲存和載入 torch.nn.Modules#

另請參閱:教程:儲存和載入模型

在 PyTorch 中,模組的狀態通常使用“狀態字典”進行序列化。模組的狀態字典包含其所有引數和持久緩衝區。

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

為保證相容性,建議不要直接儲存模組,而是隻儲存其狀態字典。Python 模組甚至有一個函式 load_state_dict(),用於從狀態字典恢復其狀態。

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

請注意,狀態字典首先使用 torch.load() 從檔案中載入,然後使用 load_state_dict() 恢復狀態。

即使是自定義模組和包含其他模組的模組,也都有狀態字典,並可以使用此模式。

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

序列化檔案格式 for torch.save#

自 PyTorch 1.6.0 起,torch.save 預設返回一個未壓縮的 ZIP64 存檔,除非使用者將 _use_new_zipfile_serialization 設定為 False

在此存檔中,檔案按以下順序排列:

checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
條目如下:
  • data.pkl 是對傳遞給 torch.save 的物件進行 pickle 的結果,但不包含其中包含的 torch.Storage 物件。

  • byteorder 包含一個字串,其中是儲存時 sys.byteorder 的值(“little” 或 “big”)。

  • data/ 包含物件中的所有儲存,每個儲存是一個單獨的檔案。

  • version 包含儲存時的版本號,可以在載入時使用。

儲存時,PyTorch 將確保每個檔案的本地檔案頭都會填充到 64 位元組的倍數偏移量,從而確保每個檔案的偏移量都是 64 位元組對齊的。

注意

某些裝置(如 XLA)上的張量被序列化為 pickled numpy 陣列。因此,它們的儲存不會被序列化。在這種情況下,檢查點中可能不存在 data/ 目錄。

佈局控制#

torch.load() 中的 mmap 引數允許對張量儲存進行延遲載入。

此外,還有一些高階功能允許對 torch.save 檢查點進行更細粒度的控制和操作。

使用 torch.serialization.skip_data 上下文管理器可以:
  • 使用 torch.save 儲存一個包含資料位元組預留空間的檢查點,以便之後寫入。

  • 使用 torch.load 載入一個檢查點,並稍後填充張量的資料位元組。

要檢查 torch.save 檢查點中的張量元資料而不分配儲存資料記憶體,請在 FakeTensorMode 上下文管理器中使用 torch.load。除了跳過載入儲存資料(類似於上面的 skip_data)之外,它還會將儲存標記上其在檢查點內的偏移量,從而可以直接操作檢查點。

import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

m = nn.Linear(10, 10)
torch.save(m.state_dict(), "checkpoint.pt")

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, dtype={v.dtype}, shape={v.shape}, stride={v.stride()}, storage_offset={v.storage_offset()}")
    # offset of the storage in the checkpoint
    print(f"key={k}, checkpoint_offset={v.untyped_storage()._checkpoint_offset}")

有關更多資訊,此教程提供了使用這些功能操作檢查點的綜合示例。

torch.load with weights_only=True#

從 2.6 版本開始,如果未傳遞 pickle_module 引數,torch.load 將使用 weights_only=True

torch.load() 的文件中所述,weights_only=Truetorch.load 中使用的反 pickle 模組限制為僅執行 torch.Tensorsstate_dicts 以及其他一些基本型別所需的函式/類。此外,與 pickle 模組提供的預設 Unpickler 不同,weights_only Unpickler 不允許在反 pickling 過程中動態匯入任何內容。

如上所述,使用 torch.save 儲存模組的 state_dict 是最佳實踐。如果載入包含 nn.Module 的舊檢查點,我們建議使用 weights_only=False。載入包含張量子類(tensor subclasses)的檢查點時,很可能會有需要新增到白名單的函式/類,有關詳細資訊,請參閱下文。

如果 weights_only Unpickler 遇到一個預設未被白名單的函式或類,您應該會看到類似以下的錯誤訊息:

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded,
to do so you have two options, do those steps only if you trust the source of the checkpoint.
    1. Re-running `torch.load` with `weights_only` set to `False` will likely succeed,
        but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
    2. Alternatively, to load with `weights_only=True` please check the recommended
       steps in the following error message.
       WeightsUnpickler error: Unsupported global: GLOBAL {__module__}.{__name__} was not an allowed global by
       default. Please use `torch.serialization.add_safe_globals([{__name__}])` or the
       `torch.serialization.safe_globals([{__name__}])` context manager to allowlist this global
       if you trust this class/function.

請按照錯誤訊息中的步驟操作,並僅在您信任這些函式或類時才將其新增到白名單。

要獲取檢查點中尚未被白名單的所有全域性變數(函式/類),您可以使用 torch.serialization.get_unsafe_globals_in_checkpoint(),它將返回一個字串列表,格式為 {__module__}.{__name__}。如果您信任這些函式/類,可以匯入它們並透過 torch.serialization.add_safe_globals() 或使用 torch.serialization.safe_globals 上下文管理器將它們新增到白名單。

要訪問使用者白名單的函式/類列表,可以使用 torch.serialization.get_safe_globals(),要清除當前列表,請參閱 torch.serialization.clear_safe_globals()

解決 weights_only 問題#

獲取不安全的全域性變數#

需要注意的是,torch.serialization.get_unsafe_globals_in_checkpoint() 會對檢查點進行靜態分析,某些型別可能在反 pickling 過程中動態構建,因此不會被 torch.serialization.get_unsafe_globals_in_checkpoint() 報告。其中一個例子是 numpy 中的 dtypes。在 numpy < 1.25 中,在將 torch.serialization.get_unsafe_globals_in_checkpoint() 報告的所有函式/類新增到白名單後,您可能會看到類似以下的錯誤:

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtype[float32]'>

這可以透過 {add_}safe_globals([type(np.dtype(np.float32))]) 新增到白名單。

numpy >=1.25 中,您會看到:

WeightsUnpickler error: Can only build Tensor, Parameter, OrderedDict or types allowlisted via `add_safe_globals`,
but got <class 'numpy.dtypes.Float32DType'>

這可以透過 {add_}safe_globals([np.dtypes.Float32DType]) 新增到白名單。

環境變數#

有兩個環境變數會影響 torch.load 的行為。如果您無法訪問 torch.load 的呼叫點,這些變數會很有幫助。

  • TORCH_FORCE_WEIGHTS_ONLY_LOAD=1 將覆蓋所有 torch.load 的呼叫點,使其使用 weights_only=True

  • TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1 將使 torch.load 呼叫點使用 weights_only=False,**僅當** weights_only 未作為引數傳遞時。

實用函式#

以下實用函式與序列化相關:

torch.serialization.register_package(priority, tagger, deserializer)[source]#

註冊用於為具有關聯優先順序的儲存物件進行標記和反序列化的可呼叫物件。標記在儲存時將裝置與儲存物件關聯,而反序列化在載入時將儲存物件移動到合適的裝置。taggerdeserializer 將按其 priority 指定的順序執行,直到 tagger/deserializer 返回一個非 None 的值。

要覆蓋全域性登錄檔中某個裝置的反序列化行為,可以註冊一個優先順序高於現有 tagger 的 tagger。

此函式也可用於為新設備註冊 tagger 和 deserializer。

引數
返回

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_crc32_options()[source]#

獲取 torch.save() 是否為每個記錄計算並寫入 crc32。

預設為 True

返回型別

布林值

torch.serialization.set_crc32_options(compute_crc32)[source]#

設定 torch.save() 是否為每個記錄計算並寫入 crc32。

注意

將其設定為 False 可能會導致 torch.save 輸出的解壓失敗或因 CRC32 損壞而發出警告。但是 torch.load 仍然可以載入檔案。

引數

compute_crc32 (bool) – 設定 crc32 計算標誌

torch.serialization.get_default_load_endianness()[source]#

獲取載入檔案的後備位元組序。

如果已儲存的檢查點中不存在位元組序標記,則使用此位元組序作為後備。預設情況下,它是“本地”(native)位元組序。

返回

Optional[LoadEndianness]

返回型別

default_load_endian

torch.serialization.set_default_load_endianness(endianness)[source]#

設定載入檔案的後備位元組序。

如果已儲存的檢查點中不存在位元組序標記,則使用此位元組序作為後備。預設情況下,它是“本地”(native)位元組序。

引數

endianness – 新的後備位元組序

torch.serialization.get_default_mmap_options()[source]#

獲取 torch.load()mmap=True 的預設 mmap 選項。

預設為 mmap.MAP_PRIVATE

返回

int

返回型別

default_mmap_options

torch.serialization.set_default_mmap_options(flags)[source]#

上下文管理器或函式,用於為 torch.load()mmap=True 設定預設 mmap 選項為 flags。

目前,只支援 mmap.MAP_PRIVATEmmap.MAP_SHARED。如果您需要新增其他選項,請提交一個 issue。

注意

此功能目前不支援 Windows。

引數

flags (int) – mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[source]#

將給定的全域性變數標記為 weights_only 載入是安全的。例如,新增到此列表中的函式可以在反 pickling 過程中被呼叫,類可以被例項化並設定狀態。

列表中的每個項可以是函式/類,或者是一個元組,形式為 (函式/類, 字串),其中字串是函式/類的完整路徑。

在序列化格式中,每個函式都用其完整路徑 {__module__}.{__qualname__} 來標識。呼叫此 API 時,您可以提供應與檢查點中的路徑匹配的完整路徑,否則將使用預設的 {fn.__module__}.{fn.__qualname__}

引數

safe_globals (List[Union[Callable, Tuple[Callable, str]]) – 要標記為安全的全域性變數列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[source]#

清除對 weights_only 載入是安全的全域性變數列表。

torch.serialization.get_safe_globals()[source]#

返回使用者新增的、對 weights_only 載入是安全的全域性變數列表。

返回型別

list[Union[Callable, tuple[Callable, str]]]

torch.serialization.get_unsafe_globals_in_checkpoint(f)[source]#

返回 torch.save 物件中不安全(不適用於 weights_only 載入)的函式/類的字串列表。

對於給定的函式或類 f,相應的字串格式為 {f.__module__}.{f.__name__}

此函式將返回檢查點中所有不屬於 weights_only 安全集(透過 add_safe_globals()safe_globals 上下文或 torch 預設白名單)的全域性變數。

注意

此函式將靜態地反彙編檢查點中的 pickle 檔案。這意味著任何在反 pickling 過程中動態推送到堆疊的類都不會包含在輸出中。

引數

f (Union[str, PathLike[str], IO[bytes]]) – 透過 torch.save 儲存的檢查點物件的檔案類物件或字串。

返回

檢查點中的 pickle 全域性變數列表,這些變數未被 weights_only 白名單。

返回型別

list[str]

class torch.serialization.safe_globals(safe_globals)[source]#

上下文管理器,用於將某些全域性變數新增為 weights_only 載入的安全項。

引數

safe_globals (list[Union[Callable, tuple[Callable, str]]]) – 用於 weights_only 載入的全域性變數列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
class torch.serialization.skip_data(materialize_fake_tensors=False)[source]#

上下文管理器,用於跳過 torch.save / torch.load 呼叫中的儲存位元組的寫入/讀取。

對於儲存路徑,儲存仍會被儲存,但其位元組通常會寫入的空間將為空。然後可以在單獨的傳遞中填充儲存位元組。

對於載入路徑,張量將根據檢查點載入,但其儲存不會填充資料。

警告

skip_data 上下文管理器是一個早期原型,可能會發生更改。

引數

materialize_fake_tensors (bool) – 儲存時是否具體化 FakeTensors。這對載入路徑是無操作。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])

配置#

torch.utils.serialization.config 提供了一個全域性配置,可以控制 torch.savetorch.load 的行為。

torch.utils.serialization.config.save 包含控制 torch.save 行為的選項。

  • compute_crc32: 是否計算並寫入 zip 檔案校驗和 (預設: True)。請參閱 set_crc32_options()

  • use_pinned_memory_for_d2h: 對於傳遞到 torch.save 時位於加速器上的儲存,是否在 torch.save 中將儲存移動到 CPU 的固定記憶體或可分頁記憶體。(預設: False (即可分頁))

  • storage_alignment: 在 torch.save 期間,檢查點中儲存的對齊位元組數。(預設 64)

torch.utils.serialization.config.load 包含控制 torch.load 行為的選項。

  • mmap: 請參閱 torch.load()mmap 引數的文件。此配置將設定 torch.loadmmap 行為,如果它沒有被顯式傳遞給 torch.load 呼叫的話 (預設: False)。

  • endianness: 請參閱 set_default_load_endianness()。(預設: torch.serialization.LoadEndianness.NATIVE)

  • mmap_flags: 請參閱 set_default_mmap_options。(預設: MAP_PRIVATE)

  • calculate_storage_offsets: 如果此配置設定為 True,則在使用 torch.load(mmap=True) 時,將計算儲存的偏移量,而不是透過隨機讀取來獲取。這可以最大限度地減少隨機讀取,當檔案透過網路載入時可能會很有用。(預設: False)