評價此頁

ScriptModule#

class torch.jit.ScriptModule[source]#

Wrapper for C++ torch::jit::Module with methods, attributes, and parameters.

C++ torch::jit::Module 的封裝。 ScriptModule 包含方法、屬性、引數和常量。這些可以與普通 nn.Module 相同的方式訪問。

add_module(name, module)[source]#

將子模組新增到當前模組。

可以使用給定的名稱作為屬性訪問該模組。

引數
  • name (str) – 子模組的名稱。子模組可以透過給定名稱從該模組訪問。

  • module (Module) – 要新增到模組的子模組。

apply(fn)[source]#

fn 遞迴應用於每個子模組(由 .children() 返回)以及自身。

典型用途包括初始化模型的引數(另請參閱 torch.nn.init)。

引數

fn (Module -> None) – 要應用於每個子模組的函式

返回

self

返回型別

模組

示例

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16()[source]#

將所有浮點引數和緩衝區轉換為 bfloat16 資料型別。

注意

此方法就地修改模組。

返回

self

返回型別

模組

buffers(recurse=True)[source]#

返回模組緩衝區的迭代器。

引數

recurse (bool) – 如果為 True,則會生成此模組及所有子模組的緩衝區。否則,只生成此模組的直接成員緩衝區。

生成

torch.Tensor – 模組緩衝區

返回型別

Iterator[Tensor]

示例

>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
children()[source]#

返回直接子模組的迭代器。

生成

Module – 子模組

返回型別

Iterator[Module]

property code#

Return a pretty-printed representation (as valid Python syntax) of the internal graph for the forward method.

property code_with_constants#

Return a tuple.

Returns a tuple of

[0] a pretty-printed representation (as valid Python syntax) of the internal graph for the forward method. See code. [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant’s values.

compile(*args, **kwargs)[source]#

使用 torch.compile() 編譯此模組的 forward。

此模組的 __call__ 方法已編譯,所有引數將按原樣傳遞給 torch.compile()

有關此函式引數的詳細資訊,請參閱 torch.compile()

cpu()[source]#

將所有模型引數和緩衝區移動到 CPU。

注意

此方法就地修改模組。

返回

self

返回型別

模組

cuda(device=None)[source]#

將所有模型引數和緩衝區移動到 GPU。

這也會使相關的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 GPU 上,則應在構建最佳化器之前呼叫此函式。

注意

此方法就地修改模組。

引數

device (int, optional) – 如果指定,所有引數都將複製到該裝置。

返回

self

返回型別

模組

double()[source]#

將所有浮點引數和緩衝區轉換為 double 資料型別。

注意

此方法就地修改模組。

返回

self

返回型別

模組

eval()[source]#

將模組設定為評估模式。

這僅對某些模組有影響。有關模組在訓練/評估模式下的行為,例如它們是否受影響(如 DropoutBatchNorm 等),請參閱具體模組的文件。

This is equivalent with self.train(False).

請參閱 區域性停用梯度計算,瞭解 .eval() 與一些可能與之混淆的類似機制之間的比較。

返回

self

返回型別

模組

extra_repr()[source]#

返回模組的額外表示。

要列印自定義額外資訊,您應該在自己的模組中重新實現此方法。單行和多行字串均可接受。

返回型別

str

float()[source]#

將所有浮點引數和緩衝區轉換為 float 資料型別。

注意

此方法就地修改模組。

返回

self

返回型別

模組

get_buffer(target)[source]#

返回由 target 給定的緩衝區(如果存在),否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數

target (str) – 要查詢的緩衝區的完整限定字串名稱。(有關如何指定完整限定字串,請參閱 get_submodule。)

返回

target 引用的緩衝區

返回型別

torch.Tensor

引發

AttributeError – 如果目標字串引用了無效路徑或解析為非緩衝區項。

get_extra_state()[source]#

返回要包含在模組 state_dict 中的任何額外狀態。

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().

注意,為了保證 state_dict 的序列化工作正常,額外狀態應該是可被 pickle 的。我們僅為 Tensors 的序列化提供向後相容性保證;其他物件的序列化形式若發生變化,可能導致向後相容性中斷。

返回

要儲存在模組 state_dict 中的任何額外狀態

返回型別

物件

get_parameter(target)[source]#

如果存在,返回由 target 給定的引數,否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數

target (str) – 要查詢的引數的完整限定字串名稱。(有關如何指定完整限定字串,請參閱 get_submodule。)

返回

target 引用的引數

返回型別

torch.nn.Parameter

引發

AttributeError – 如果目標字串引用了無效路徑或解析為非 nn.Parameter 項。

get_submodule(target)[source]#

如果存在,返回由 target 給定的子模組,否則丟擲錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要檢查是否存在 linear 子模組,可以呼叫 get_submodule("net_b.linear")。要檢查是否存在 conv 子模組,可以呼叫 get_submodule("net_b.net_c.conv")

get_submodule 的執行時複雜度受 target 中模組巢狀深度的限制。與 named_modules 的查詢相比,後者的複雜度是按傳遞模組數量計算的 O(N)。因此,對於簡單地檢查某個子模組是否存在,應始終使用 get_submodule

引數

target (str) – 要查詢的子模組的完整限定字串名稱。(如上例所示,如何指定完整限定字串。)

返回

target 引用的子模組

返回型別

torch.nn.Module

引發

AttributeError – 如果在 target 字串解析出的路徑中的任何一點,(子)路徑解析為一個不存在的屬性名或一個非 nn.Module 例項的物件。

property graph#

Return a string representation of the internal graph for the forward method.

half()[source]#

將所有浮點引數和緩衝區轉換為 half 資料型別。

注意

此方法就地修改模組。

返回

self

返回型別

模組

property inlined_graph#

Return a string representation of the internal graph for the forward method.

This graph will be preprocessed to inline all function and method calls.

ipu(device=None)[source]#

將所有模型引數和緩衝區移動到 IPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 IPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數

device (int, optional) – 如果指定,所有引數都將複製到該裝置。

返回

self

返回型別

模組

load_state_dict(state_dict, strict=True, assign=False)[source]#

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

警告

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True.

引數
  • state_dict (dict) – 包含引數和持久緩衝區的字典。

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  • assign (bool, optional) – 當設定為 False 時,將保留當前模組中張量的屬性;設定為 True 時,將保留 state dict 中張量的屬性。唯一的例外是 Parameterrequires_grad 欄位,此時將保留模組中的值。預設為 False

返回

  • missing_keys 是一個包含此模組期望但

    在提供的 state_dict 中缺失的任何鍵的字串列表。

  • unexpected_keys 是一個字串列表,包含此模組

    不期望但在提供的 state_dict 中存在的鍵。

返回型別

NamedTuple,包含 missing_keysunexpected_keys 欄位。

注意

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

modules()[source]#

返回網路中所有模組的迭代器。

生成

Module – 網路中的一個模組

返回型別

Iterator[Module]

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
mtia(device=None)[source]#

將所有模型引數和緩衝區移動到 MTIA。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 MTIA 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數

device (int, optional) – 如果指定,所有引數都將複製到該裝置。

返回

self

返回型別

模組

named_buffers(prefix='', recurse=True, remove_duplicate=True)[source]#

返回模組緩衝區上的迭代器,同時生成緩衝區的名稱和緩衝區本身。

引數
  • prefix (str) – 要新增到所有緩衝區名稱的字首。

  • recurse (bool, optional) – 如果為 True,則會生成此模組及所有子模組的緩衝區。否則,只生成此模組的直接成員緩衝區。預設為 True。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的緩衝區。預設為 True。

生成

(str, torch.Tensor) – 包含名稱和緩衝區的元組

返回型別

Iterator[tuple[str, torch.Tensor]]

示例

>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_children()[source]#

返回對直接子模組的迭代器,生成模組的名稱和模組本身。

生成

(str, Module) – 包含名稱和子模組的元組

返回型別

Iterator[tuple[str, ‘Module’]]

示例

>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
named_modules(memo=None, prefix='', remove_duplicate=True)[source]#

返回網路中所有模組的迭代器,同時生成模組的名稱和模組本身。

引數
  • memo (Optional[set['Module']]) – 用於儲存已新增到結果中的模組集合的備忘錄。

  • prefix (str) – 將新增到模組名稱的字首。

  • remove_duplicate (bool) – 是否從結果中刪除重複的模組例項。

生成

(str, Module) – 名稱和模組的元組

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix='', recurse=True, remove_duplicate=True)[source]#

返回模組引數的迭代器,同時生成引數的名稱和引數本身。

引數
  • prefix (str) – 要新增到所有引數名稱的字首。

  • recurse (bool) – 如果為 True,則會生成此模組及所有子模組的引數。否則,只生成此模組的直接成員引數。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的引數。預設為 True。

生成

(str, Parameter) – 包含名稱和引數的元組

返回型別

Iterator[tuple[str, torch.nn.parameter.Parameter]]

示例

>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse=True)[source]#

返回模組引數的迭代器。

這通常傳遞給最佳化器。

引數

recurse (bool) – 如果為 True,則會生成此模組及所有子模組的引數。否則,只生成此模組的直接成員引數。

生成

Parameter – 模組引數

返回型別

Iterator[Parameter]

示例

>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
register_backward_hook(hook)[source]#

在模組上註冊一個反向傳播鉤子。

This function is deprecated in favor of register_full_backward_hook() and the behavior of this function will change in future versions.

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_buffer(name, tensor, persistent=True)[source]#

向模組新增一個緩衝區。

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.

可以使用給定名稱作為屬性訪問緩衝區。

引數
  • name (str) – 緩衝區的名稱。緩衝區可以透過給定名稱從該模組訪問。

  • tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict.

  • persistent (bool) – whether the buffer is part of this module’s state_dict.

示例

>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)[source]#

在模組上註冊一個前向鉤子。

The hook will be called every time after forward() has computed an output.

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called. The hook should have the following signature

hook(module, args, output) -> None or modified output

如果 with_kwargsTrue,則前向鉤子將接收傳遞給 forward 函式的 kwargs,並需要返回可能已修改的輸出。鉤子應該具有以下簽名

hook(module, args, kwargs, output) -> None or modified output
引數
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – 如果為 True,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

  • always_call (bool) – 如果為 True,則無論呼叫 Module 時是否發生異常,都將執行 hook。預設為 False

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)[source]#

在模組上註冊一個前向預鉤子。

The hook will be called every time before forward() is invoked.

如果 with_kwargs 為 false 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,而只會傳遞給 forward。鉤子可以修改輸入。使用者可以返回一個元組或單個修改後的值。我們將把值包裝成一個元組,如果返回的是單個值(除非該值本身就是元組)。鉤子應該具有以下簽名

hook(module, args) -> None or modified input

如果 with_kwargs 為 true,則前向預鉤子將接收傳遞給 forward 函式的 kwargs。如果鉤子修改了輸入,則應該返回 args 和 kwargs。鉤子應該具有以下簽名

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
引數
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False

  • with_kwargs (bool) – 如果為 true,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook, prepend=False)[source]#

在模組上註冊一個反向傳播鉤子。

每次計算相對於模組的梯度時,將呼叫此鉤子,其觸發規則如下:

  1. 通常,鉤子在計算相對於模組輸入的梯度時觸發。

  2. 如果模組輸入都不需要梯度,則在計算相對於模組輸出的梯度時觸發鉤子。

  3. 如果模組輸出都不需要梯度,則鉤子將不觸發。

鉤子應具有以下簽名

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments.

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入或輸出,否則將引發錯誤。

引數
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method.

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook, prepend=False)[source]#

在模組上註冊一個反向預鉤子。

每次計算模組的梯度時,將呼叫此鉤子。鉤子應具有以下簽名

hook(module, grad_output) -> tuple[Tensor] or None

grad_output 是一個元組。鉤子不應修改其引數,但可以選擇返回一個新的輸出梯度,該梯度將取代 grad_output 用於後續計算。對於所有非 Tensor 引數,grad_output 中的條目將為 None

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入,否則將引發錯誤。

引數
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method.

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)[source]#

註冊一個後鉤子,用於在模組的 load_state_dict() 被呼叫後執行。

它應該具有以下簽名:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys.

如果需要,可以就地修改給定的 incompatible_keys。

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error.

返回

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook)[source]#

註冊一個預鉤子,用於在模組的 load_state_dict() 被呼叫之前執行。

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950

引數

hook (Callable) – 在載入狀態字典之前將呼叫的可呼叫鉤子。

register_module(name, module)[source]#

Alias for add_module().

register_parameter(name, param)[source]#

向模組新增一個引數。

可以使用給定名稱作為屬性訪問該引數。

引數
  • name (str) – 引數的名稱。引數可以透過給定名稱從該模組訪問。

  • param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict.

register_state_dict_post_hook(hook)[source]#

Register a post-hook for the state_dict() method.

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata) -> None

註冊的鉤子可以就地修改 state_dict

register_state_dict_pre_hook(hook)[source]#

Register a pre-hook for the state_dict() method.

它應該具有以下簽名:

hook(module, prefix, keep_vars) -> None

註冊的鉤子可用於在進行 state_dict 呼叫之前執行預處理。

requires_grad_(requires_grad=True)[source]#

更改自動梯度是否應記錄此模組中引數的操作。

此方法就地設定引數的 requires_grad 屬性。

此方法有助於凍結模組的一部分以進行微調或單獨訓練模型的一部分(例如,GAN 訓練)。

請參閱 區域性停用梯度計算,瞭解 .requires_grad_() 與一些可能與之混淆的類似機制之間的比較。

引數

requires_grad (bool) – 是否應為此模組中的引數啟用自動求導。預設為 True

返回

self

返回型別

模組

save(f, **kwargs)[source]#

Save with a file-like object.

save(f, _extra_files={})

See torch.jit.save which accepts a file-like object. This function, torch.save(), converts the object to a string, treating it as a path. DO NOT confuse these two functions when it comes to the ‘f’ parameter functionality.

set_extra_state(state)[source]#

設定載入的 state_dict 中包含的額外狀態。

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.

引數

state (dict) – 來自 state_dict 的額外狀態。

set_submodule(target, module, strict=False)[source]#

如果存在,設定由 target 給定的子模組,否則丟擲錯誤。

注意

如果 strict 設定為 False(預設),該方法將替換現有子模組或在父模組存在的情況下建立新子模組。如果 strict 設定為 True,該方法將僅嘗試替換現有子模組,並在子模組不存在時引發錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要用一個新的 Linear 子模組覆蓋 Conv2d,可以呼叫 set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中 strict 可以是 TrueFalse

要將一個新的 Conv2d 子模組新增到現有的 net_b 模組中,可以呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))

在上面,如果設定 strict=True 並呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),則會引發 AttributeError,因為 net_b 中不存在名為 conv 的子模組。

引數
  • target (str) – 要查詢的子模組的完整限定字串名稱。(如上例所示,如何指定完整限定字串。)

  • module (Module) – The module to set the submodule to.

  • strict (bool) – 如果為 False,則該方法將替換現有子模組或在父模組存在的情況下建立新子模組。如果為 True,則該方法將僅嘗試替換現有子模組,並在子模組不存在時引發錯誤。

引發
  • ValueError – 如果 target 字串為空,或者 module 不是 nn.Module 的例項。

  • AttributeError – 如果在 target 字串解析出的路徑中的任何一點,(子)路徑解析為一個不存在的屬性名或一個非 nn.Module 例項的物件。

share_memory()[source]#

請參閱 torch.Tensor.share_memory_()

返回型別

自我

state_dict(*args, destination=None, prefix='', keep_vars=False)[source]#

返回一個字典,其中包含對模組整個狀態的引用。

引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為 None 的引數和緩衝區不包含在內。

注意

返回的物件是淺複製。它包含對模組引數和緩衝區的引用。

警告

當前 state_dict() 還接受 destinationprefixkeep_vars 的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

引數
  • destination (dict, optional) – 如果提供,模組的狀態將被更新到字典中,並返回同一個物件。否則,將建立一個 OrderedDict 並返回。預設為 None

  • prefix (str, optional) – 新增到引數和緩衝區名稱的字首,用於構成 state_dict 中的鍵。預設為 ''

  • keep_vars (bool, optional) – 預設情況下,state_dict 中返回的 Tensor s 會從自動求導中分離。如果設定為 True,則不會進行分離。預設為 False

返回

包含模組整體狀態的字典

返回型別

dict

示例

>>> module.state_dict().keys()
['bias', 'weight']
to(*args, **kwargs)[source]#

移動和/或轉換引數和緩衝區。

這可以這樣呼叫

to(device=None, dtype=None, non_blocking=False)[source]
to(dtype, non_blocking=False)[原始碼]
to(tensor, non_blocking=False)[原始碼]
to(memory_format=torch.channels_last)[原始碼]

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

有關示例,請參閱下文。

注意

此方法就地修改模組。

引數
  • device (torch.device) – 此模組中的引數和緩衝區的目標裝置

  • dtype (torch.dtype) – 此模組中的引數和緩衝區的目標浮點數或複數 dtype

  • tensor (torch.Tensor) – 張量,其 dtype 和裝置是此模組中所有引數和緩衝區的所需 dtype 和裝置

  • memory_format (torch.memory_format) – 此模組中 4D 引數和緩衝區的目標記憶體格式(僅限關鍵字引數)

返回

self

返回型別

模組

示例

>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device, recurse=True)[source]#

將引數和緩衝區移動到指定裝置,而不復制儲存。

引數
  • device (torch.device) – 此模組中的引數和緩衝區的目標裝置。

  • recurse (bool) – 是否遞迴地將子模組的引數和緩衝區移動到指定裝置。

返回

self

返回型別

模組

train(mode=True)[source]#

將模組設定為訓練模式。

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc. – 這隻對某些模組有影響。有關其在訓練/評估模式下的行為的詳細資訊,例如它們是否受影響,請參閱特定模組的文件,例如 DropoutBatchNorm 等。

引數

mode (bool) – 設定訓練模式(True)還是評估模式(False)。預設值:True

返回

self

返回型別

模組

type(dst_type)[source]#

將所有引數和緩衝區轉換為 dst_type

注意

此方法就地修改模組。

引數

dst_type (typestring) – 目標型別

返回

self

返回型別

模組

xpu(device=None)[source]#

將所有模型引數和緩衝區移動到 XPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 XPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數

device (int, optional) – 如果指定,所有引數都將複製到該裝置。

返回

self

返回型別

模組

zero_grad(set_to_none=True)[source]#

重置所有模型引數的梯度。

請參閱 torch.optim.Optimizer 下的類似函式以獲取更多上下文。

引數

set_to_none (bool) – 不設定為零,而是將梯度設定為 None。有關詳細資訊,請參閱 torch.optim.Optimizer.zero_grad()