快捷方式

QueryModule

class torchrl.data.QueryModule(*args, **kwargs)[原始碼]

一個用於生成相容儲存索引的模組。

一個查詢儲存並返回該儲存所需索引的模組。目前,它只輸出整數索引(torch.int64)。

引數:
  • in_keys (list of NestedKeys) – 將用於生成雜湊值的輸入 tensordict 的鍵。

  • index_key (NestedKey) – 將寫入索引值的輸出鍵。預設為 "_index"

關鍵字引數:
  • hash_key (NestedKey) – 將寫入雜湊值的輸出鍵。預設為 "_hash"

  • hash_module (Callable[[Any], int] or a list of these, optional) – 一個類似於 SipHash(預設)的雜湊模組。如果提供了一個可呼叫物件列表,其長度必須等於 in_keys 的數量。

  • hash_to_int (Callable[[int], int], optional) – 一個有狀態的函式,將雜湊值對映到一個非負整數,該整數對應於儲存中的索引。預設為 HashToInt

  • aggregator (Callable[[int], int], optional) – 一個用於將多個雜湊值組合在一起的雜湊函式。當有多個 in_keys 時,才應傳遞此引數。如果提供了一個 hash_module 但未傳遞 aggregator,則它將採用 hash_module 的值。如果未提供 hash_module 或提供了 hash_modules 的列表但未傳遞 aggregator,則將預設為 SipHash

  • clone (bool, optional) – 如果為 True,則將返回輸入 TensorDict 的淺複製。這可用於檢索儲存中對應於給定輸入 tensordict 的整數索引。這可以透過在 forward 方法中提供 clone 引數來覆蓋。預設為 False

示例

>>> query_module = QueryModule(
...     in_keys=["key1", "key2"],
...     index_key="index",
...     hash_module=SipHash(),
... )
>>> query = TensorDict(
...     {
...         "key1": torch.Tensor([[1], [1], [1], [2]]),
...         "key2": torch.Tensor([[3], [3], [2], [3]]),
...         "other": torch.randn(4),
...     },
...     batch_size=(4,),
... )
>>> res = query_module(query)
>>> # The first two pairs of key1 and key2 match
>>> assert res["index"][0] == res["index"][1]
>>> # The last three pairs of key1 and key2 have at least one mismatching value
>>> assert res["index"][1] != res["index"][2]
>>> assert res["index"][2] != res["index"][3]
add_module(name: str, module: Optional[Module]) None

將子模組新增到當前模組。

可以使用給定的名稱作為屬性訪問該模組。

引數:
  • name (str) – 子模組的名稱。子模組可以透過給定名稱從此模組訪問

  • module (Module) – 要新增到模組中的子模組。

apply(fn: Callable[[Module], None]) Self

fn 遞迴應用於每個子模組(由 .children() 返回)以及自身。

典型用法包括初始化模型引數(另請參閱 torch.nn.init)。

引數:

fn (Module -> None) – 要應用於每個子模組的函式

返回:

self

返回型別:

模組

示例

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16() Self

將所有浮點引數和緩衝區轉換為 bfloat16 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

buffers(recurse: bool = True) Iterator[Tensor]

返回模組緩衝區的迭代器。

引數:

recurse (bool) – 如果為 True,則會產生此模組及其所有子模組的 buffer。否則,僅會產生此模組的直接成員 buffer。

產生:

torch.Tensor – 模組緩衝區

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
children() Iterator[Module]

返回直接子模組的迭代器。

產生:

Module – 子模組

compile(*args, **kwargs)

使用 torch.compile() 編譯此 Module 的前向傳播。

此 Module 的 __call__ 方法將被編譯,並且所有引數將按原樣傳遞給 torch.compile()

有關此函式的引數的詳細資訊,請參閱 torch.compile()

cpu() Self

將所有模型引數和緩衝區移動到 CPU。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

cuda(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 GPU。

這也會使相關的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 GPU 上,則應在構建最佳化器之前呼叫此函式。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

double() Self

將所有浮點引數和緩衝區轉換為 double 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

eval() Self

將模組設定為評估模式。

這僅對某些模組有影響。有關模組在訓練/評估模式下的行為,例如它們是否受影響(如 DropoutBatchNorm 等),請參閱具體模組的文件。

這等同於 self.train(False)

有關 .eval() 和幾種可能與之混淆的類似機制之間的比較,請參閱 區域性停用梯度計算

返回:

self

返回型別:

模組

extra_repr() str

返回模組的額外表示。

要列印自定義額外資訊,您應該在自己的模組中重新實現此方法。單行和多行字串均可接受。

float() Self

將所有浮點引數和緩衝區轉換為 float 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

forward(tensordict: TensorDictBase, *, extend: bool = True, write_hash: bool = True, clone: bool | None = None) TensorDictBase[原始碼]

定義每次呼叫時執行的計算。

所有子類都應重寫此方法。

注意

儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫 Module 例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。

get_buffer(target: str) Tensor

返回由 target 給定的緩衝區(如果存在),否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數:

target – 要查詢的 buffer 的完全限定字串名稱。(要指定完全限定字串,請參閱 get_submodule。)

返回:

target 引用的緩衝區

返回型別:

torch.Tensor

丟擲:

AttributeError – 如果目標字串引用了無效路徑或解析為非 buffer 物件。

get_extra_state() Any

返回要包含在模組 state_dict 中的任何額外狀態。

實現此功能以及相應的 set_extra_state() 來儲存額外狀態。

注意,為了保證 state_dict 的序列化工作正常,額外狀態應該是可被 pickle 的。我們僅為 Tensors 的序列化提供向後相容性保證;其他物件的序列化形式若發生變化,可能導致向後相容性中斷。

返回:

要儲存在模組 state_dict 中的任何額外狀態

返回型別:

物件

get_parameter(target: str) Parameter

如果存在,返回由 target 給定的引數,否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數:

target – 要查詢的 Parameter 的完全限定字串名稱。(要指定完全限定字串,請參閱 get_submodule。)

返回:

target 引用的引數

返回型別:

torch.nn.Parameter

丟擲:

AttributeError – 如果目標字串引用了無效路徑或解析為非 nn.Parameter 的物件。

get_submodule(target: str) Module

如果存在,返回由 target 給定的子模組,否則丟擲錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要檢查是否存在 linear 子模組,可以呼叫 get_submodule("net_b.linear")。要檢查是否存在 conv 子模組,可以呼叫 get_submodule("net_b.net_c.conv")

get_submodule 的執行時複雜度受 target 中模組巢狀深度的限制。與 named_modules 的查詢相比,後者的複雜度是按傳遞模組數量計算的 O(N)。因此,對於簡單地檢查某個子模組是否存在,應始終使用 get_submodule

引數:

target – 要查詢的子模組的完全限定字串名稱。(要指定完全限定字串,請參閱上面的示例。)

返回:

target 引用的子模組

返回型別:

torch.nn.Module

丟擲:

AttributeError – 如果在目標字串解析的任何路徑中,子路徑解析為不存在的屬性名或不是 nn.Module 例項的物件。

half() Self

將所有浮點引數和緩衝區轉換為 half 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

ipu(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 IPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 IPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

static is_tdmodule_compatible(module)

檢查模組是否與 TensorDictModule API 相容。

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

state_dict 複製引數和緩衝區到此模組及其子模組。

如果 strictTrue,那麼 state_dict 的鍵必須與此模組的 state_dict() 函式返回的鍵完全匹配。

警告

如果 assignTrue,則最佳化器必須在呼叫 load_state_dict 之後建立,除非 get_swap_module_params_on_conversion()True

引數:
  • state_dict (dict) – 包含引數和持久 buffer 的字典。

  • strict (bool, optional) – 是否嚴格強制 state_dict 中的鍵與此模組的 state_dict() 函式返回的鍵匹配。預設為 True

  • assign (bool, optional) – 當設定為 False 時,將保留當前模組中張量的屬性;當設定為 True 時,將保留 state_dict 中張量的屬性。唯一的例外是 Parameterrequires_grad 欄位,此時將保留模組的值。預設值:False

返回:

  • missing_keys 是一個包含此模組期望但

    在提供的 state_dict 中缺失的任何鍵的字串列表。

  • unexpected_keys 是一個字串列表,包含此模組

    不期望但在提供的 state_dict 中存在的鍵。

返回型別:

NamedTuple,包含 missing_keysunexpected_keys 欄位。

注意

如果引數或緩衝區被註冊為 None 且其對應的鍵存在於 state_dict 中,load_state_dict() 將引發 RuntimeError

modules() Iterator[Module]

返回網路中所有模組的迭代器。

產生:

Module – 網路中的一個模組

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
mtia(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 MTIA。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 MTIA 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.Tensor]]

返回模組緩衝區上的迭代器,同時生成緩衝區的名稱和緩衝區本身。

引數:
  • prefix (str) – 為所有 buffer 名稱新增字首。

  • recurse (bool, optional) – 如果為 True,則會生成此模組及其所有子模組的 buffers。否則,僅生成此模組直接成員的 buffers。預設為 True。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的 buffers。預設為 True。

產生:

(str, torch.Tensor) – 包含名稱和緩衝區的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_children() Iterator[tuple[str, 'Module']]

返回對直接子模組的迭代器,生成模組的名稱和模組本身。

產生:

(str, Module) – 包含名稱和子模組的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
named_modules(memo: Optional[set['Module']] = None, prefix: str = '', remove_duplicate: bool = True)

返回網路中所有模組的迭代器,同時生成模組的名稱和模組本身。

引數:
  • memo – 用於儲存已新增到結果中的模組集合的 memo

  • prefix – 將新增到模組名稱的名稱字首

  • remove_duplicate – 是否從結果中刪除重複的模組例項

產生:

(str, Module) – 名稱和模組的元組

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]]

返回模組引數的迭代器,同時生成引數的名稱和引數本身。

引數:
  • prefix (str) – 為所有引數名稱新增字首。

  • recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的引數。預設為 True。

產生:

(str, Parameter) – 包含名稱和引數的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter]

返回模組引數的迭代器。

這通常傳遞給最佳化器。

引數:

recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

產生:

Parameter – 模組引數

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
register_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]]) RemovableHandle

在模組上註冊一個反向傳播鉤子。

此函式已棄用,建議使用 register_full_backward_hook(),並且此函式在未來版本中的行為將發生變化。

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None

向模組新增一個緩衝區。

這通常用於註冊一個不應被視為模型引數的緩衝區。例如,BatchNorm 的 running_mean 不是引數,但它是模組狀態的一部分。緩衝區預設是持久的,並將與引數一起儲存。透過將 persistent 設定為 False 可以改變此行為。持久緩衝區和非持久緩衝區之間的唯一區別是後者不會成為此模組 state_dict 的一部分。

可以使用給定名稱作為屬性訪問緩衝區。

引數:
  • name (str) – buffer 的名稱。可以使用給定的名稱從此模組訪問 buffer

  • tensor (Tensor or None) – 要註冊的緩衝區。如果為 None,則忽略在緩衝區上執行的操作,例如 cuda。如果為 None,則該緩衝區 **不** 包含在模組的 state_dict 中。

  • persistent (bool) – 緩衝區是否是此模組 state_dict 的一部分。

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_forward_hook(hook: Union[Callable[[T, tuple[Any, ...], Any], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

在模組上註冊一個前向鉤子。

在每次 forward() 計算出輸出後,都會呼叫此鉤子。

如果 with_kwargsFalse 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,只會傳遞給 forward。鉤子可以修改輸出。它可以就地修改輸入,但不會影響 forward,因為它是在 forward() 呼叫之後呼叫的。鉤子應具有以下簽名

hook(module, args, output) -> None or modified output

如果 with_kwargsTrue,則前向鉤子將接收傳遞給 forward 函式的 kwargs,並需要返回可能已修改的輸出。鉤子應該具有以下簽名

hook(module, args, kwargs, output) -> None or modified output
引數:
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – 如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 forward 鉤子之前觸發。否則,提供的 hook 將在所有現有 forward 鉤子之後觸發。請注意,使用 register_module_forward_hook() 註冊的全域性 forward 鉤子將在透過此方法註冊的所有鉤子之前觸發。預設為 False

  • with_kwargs (bool) – 如果為 True,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

  • always_call (bool) – 如果為 True,則無論在呼叫 Module 時是否引發異常,都會執行 hook。預設為 False

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook: Union[Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

在模組上註冊一個前向預鉤子。

在每次呼叫 forward() 之前,都會呼叫此鉤子。

如果 with_kwargs 為 false 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,而只會傳遞給 forward。鉤子可以修改輸入。使用者可以返回一個元組或單個修改後的值。我們將把值包裝成一個元組,如果返回的是單個值(除非該值本身就是元組)。鉤子應該具有以下簽名

hook(module, args) -> None or modified input

如果 with_kwargs 為 true,則前向預鉤子將接收傳遞給 forward 函式的 kwargs。如果鉤子修改了輸入,則應該返回 args 和 kwargs。鉤子應該具有以下簽名

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
引數:
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – 如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 forward_pre 鉤子之前觸發。否則,提供的 hook 將在所有現有 forward_pre 鉤子之後觸發。請注意,使用 register_module_forward_pre_hook() 註冊的全域性 forward_pre 鉤子將在透過此方法註冊的所有鉤子之前觸發。預設為 False

  • with_kwargs (bool) – 如果為 True,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模組上註冊一個反向傳播鉤子。

每次計算相對於模組的梯度時,將呼叫此鉤子,其觸發規則如下:

  1. 通常,鉤子在計算相對於模組輸入的梯度時觸發。

  2. 如果模組輸入都不需要梯度,則在計算相對於模組輸出的梯度時觸發鉤子。

  3. 如果模組輸出都不需要梯度,則鉤子將不觸發。

鉤子應具有以下簽名

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入或輸出,否則將引發錯誤。

引數:
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模組上註冊一個反向預鉤子。

每次計算模組的梯度時,將呼叫此鉤子。鉤子應具有以下簽名

hook(module, grad_output) -> tuple[Tensor] or None

grad_output 是一個元組。鉤子不應修改其引數,但可以選擇返回一個新的輸出梯度,該梯度將取代 grad_output 用於後續計算。對於所有非 Tensor 引數,grad_output 中的條目將為 None

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入,否則將引發錯誤。

引數:
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)

註冊一個後鉤子,用於在模組的 load_state_dict() 被呼叫後執行。

它應該具有以下簽名:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

如果需要,可以就地修改給定的 incompatible_keys。

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook)

註冊一個預鉤子,用於在模組的 load_state_dict() 被呼叫之前執行。

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950

引數:

hook (Callable) – 在載入狀態字典之前將呼叫的可呼叫鉤子。

register_module(name: str, module: Optional[Module]) None

add_module() 的別名。

register_parameter(name: str, param: Optional[Parameter]) None

向模組新增一個引數。

可以使用給定名稱作為屬性訪問該引數。

引數:
  • name (str) – 引數的名稱。可以透過給定名稱從該模組訪問該引數。

  • param (Parameter or None) – 要新增到模組的引數。如果為 None,則忽略在引數上執行的操作,例如 cuda。如果為 None,則該引數 **不** 包含在模組的 state_dict 中。

register_state_dict_post_hook(hook)

註冊 state_dict() 方法的後置鉤子。

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata) -> None

註冊的鉤子可以就地修改 state_dict

register_state_dict_pre_hook(hook)

註冊 state_dict() 方法的前置鉤子。

它應該具有以下簽名:

hook(module, prefix, keep_vars) -> None

註冊的鉤子可用於在進行 state_dict 呼叫之前執行預處理。

requires_grad_(requires_grad: bool = True) Self

更改自動梯度是否應記錄此模組中引數的操作。

此方法就地設定引數的 requires_grad 屬性。

此方法有助於凍結模組的一部分以進行微調或單獨訓練模型的一部分(例如,GAN 訓練)。

請參閱 本地停用梯度計算 以比較 .requires_grad_() 和幾種可能與之混淆的類似機制。

引數:

requires_grad (bool) – 自動求導是否應記錄此模組上的引數操作。預設為 True

返回:

self

返回型別:

模組

reset_out_keys()

out_keys 屬性重置為其原始值。

返回: 相同的模組,但 out_keys 值已重置。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.reset_out_keys()
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
reset_parameters_recursive(parameters: Optional[TensorDictBase] = None) Optional[TensorDictBase]

遞迴地重置模組及其子模組的引數。

引數:

parameters (TensorDict of parameters, optional) – 如果設定為 None,則模組將使用 self.parameters() 重置。否則,我們將就地重置 tensordict 中的引數。這對於引數本身不儲存在模組中的函式式模組很有用。

返回:

新引數的 tensordict,僅當 parameters 不為 None 時返回。

示例

>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> old_param = net[0].weight.clone()
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> module.reset_parameters()
>>> (old_param == net[0].weight).any()
tensor(False)

此方法還支援函式式引數取樣

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from torch import nn
>>> net = nn.Sequential(nn.Linear(2,3), nn.ReLU())
>>> module = TensorDictModule(net, in_keys=['bork'], out_keys=['dork'])
>>> params = TensorDict.from_module(module)
>>> old_params = params.clone(recurse=True)
>>> module.reset_parameters(params)
>>> (old_params == params).any()
False
select_out_keys(*out_keys) TensorDictModuleBase

選擇將在輸出 tensordict 中找到的鍵。

當一個人想丟棄複雜圖中的中間鍵,或者當這些鍵的存在可能觸發意外行為時,這很有用。

原始 out_keys 仍然可以透過 module.out_keys_source 訪問。

引數:

*out_keys (字串序列字串元組) – 應在輸出 tensordict 中找到的 out_keys。

返回: 相同的模組,以就地修改方式返回,並更新了 out_keys

最簡單的用法是與 TensorDictModule 一起使用。

示例

>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> import torch
>>> mod = TensorDictModule(lambda x, y: (x+2, y+2), in_keys=["a", "b"], out_keys=["c", "d"])
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> mod.select_out_keys("d")
>>> td = TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, [])
>>> mod(td)
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

此功能也將適用於分派的引數: .. rubric:: 示例

>>> mod(torch.zeros(()), torch.ones(()))
tensor(2.)

此更改將原地發生(即返回的同一模組將具有更新的 out_keys 列表)。您可以使用 TensorDictModuleBase.reset_out_keys() 方法撤消此操作。

示例

>>> mod.reset_out_keys()
>>> mod(TensorDict({"a": torch.zeros(()), "b": torch.ones(())}, []))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

這也將適用於其他類,例如 Sequential: .. rubric:: 示例

>>> from tensordict.nn import TensorDictSequential
>>> seq = TensorDictSequential(
...     TensorDictModule(lambda x: x+1, in_keys=["x"], out_keys=["y"]),
...     TensorDictModule(lambda x: x+1, in_keys=["y"], out_keys=["z"]),
... )
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> seq.select_out_keys("z")
>>> td = TensorDict({"x": torch.zeros(())}, [])
>>> seq(td)
TensorDict(
    fields={
        x: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
set_extra_state(state: Any) None

設定載入的 state_dict 中包含的額外狀態。

此函式從 load_state_dict() 呼叫,以處理 state_dict 中找到的任何額外狀態。實現此功能以及相應的 get_extra_state() 來儲存額外狀態。

引數:

state (dict) – 來自 state_dict 的額外狀態

set_submodule(target: str, module: Module, strict: bool = False) None

如果存在,設定由 target 給定的子模組,否則丟擲錯誤。

注意

如果 strict 設定為 False(預設),該方法將替換現有子模組或在父模組存在的情況下建立新子模組。如果 strict 設定為 True,該方法將僅嘗試替換現有子模組,並在子模組不存在時引發錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要用一個新的 Linear 子模組覆蓋 Conv2d,可以呼叫 set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中 strict 可以是 TrueFalse

要將一個新的 Conv2d 子模組新增到現有的 net_b 模組中,可以呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))

在上面,如果設定 strict=True 並呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),則會引發 AttributeError,因為 net_b 中不存在名為 conv 的子模組。

引數:
  • target – 要查詢的子模組的完全限定字串名稱。(要指定完全限定字串,請參閱上面的示例。)

  • module – 要設定子模組的物件。

  • strict – 如果為 False,該方法將替換現有子模組或建立新子模組(如果父模組存在)。如果為 True,則該方法只會嘗試替換現有子模組,如果子模組不存在則丟擲錯誤。

丟擲:
  • ValueError – 如果 target 字串為空或 module 不是 nn.Module 的例項。

  • AttributeError – 如果 target 字串路徑中的任何一點解析為一個不存在的屬性名或不是 nn.Module 例項的物件。

share_memory() Self

請參閱 torch.Tensor.share_memory_()

state_dict(*args, destination=None, prefix='', keep_vars=False)

返回一個字典,其中包含對模組整個狀態的引用。

引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為 None 的引數和緩衝區不包含在內。

注意

返回的物件是淺複製。它包含對模組引數和緩衝區的引用。

警告

當前 state_dict() 還接受 destinationprefixkeep_vars 的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

引數:
  • destination (dict, optional) – 如果提供,模組的狀態將更新到 dict 中,並返回相同的物件。否則,將建立一個 OrderedDict 並返回。預設為 None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 預設情況下,state dict 中返回的 Tensors 會從 autograd 中分離。如果設定為 True,則不會執行分離。預設為 False

返回:

包含模組整體狀態的字典

返回型別:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
to(*args, **kwargs)

移動和/或轉換引數和緩衝區。

這可以這樣呼叫

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

有關示例,請參閱下文。

注意

此方法就地修改模組。

引數:
  • device (torch.device) – the desired device of the parameters and buffers in this module – 此模組中引數和緩衝區的目標裝置。

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module – 此模組中引數和緩衝區的目標浮點數或複數 dtype。

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module – 其 dtype 和 device 是此模組中所有引數和緩衝區的目標 dtype 和 device 的 Tensor。

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument) – 此模組中 4D 引數和緩衝區的目標記憶體格式(僅關鍵字引數)。

返回:

self

返回型別:

模組

示例

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) Self

將引數和緩衝區移動到指定裝置,而不復制儲存。

引數:
  • device (torch.device) – The desired device of the parameters and buffers in this module. – 此模組中引數和緩衝區的目標裝置。

  • recurse (bool) – 是否遞迴地將子模組的引數和緩衝區移動到指定裝置。

返回:

self

返回型別:

模組

train(mode: bool = True) Self

將模組設定為訓練模式。

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc. – 這隻對某些模組有影響。有關其在訓練/評估模式下的行為的詳細資訊,例如它們是否受影響,請參閱特定模組的文件,例如 DropoutBatchNorm 等。

引數:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True. – 設定訓練模式(True)或評估模式(False)。預設值:True

返回:

self

返回型別:

模組

type(dst_type: Union[dtype, str]) Self

將所有引數和緩衝區轉換為 dst_type

注意

此方法就地修改模組。

引數:

dst_type (type or string) – 目標型別

返回:

self

返回型別:

模組

xpu(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 XPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 XPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

zero_grad(set_to_none: bool = True) None

重置所有模型引數的梯度。

See similar function under torch.optim.Optimizer for more context. – 有關更多背景資訊,請參閱 torch.optim.Optimizer 下的類似函式。

引數:

set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details. – 與其設定為零,不如將 grad 設定為 None。有關詳細資訊,請參閱 torch.optim.Optimizer.zero_grad()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源