ScriptModule#
- class torch.jit.ScriptModule[source]#
Wrapper for C++ torch::jit::Module with methods, attributes, and parameters.
C++
torch::jit::Module的封裝。ScriptModule包含方法、屬性、引數和常量。這些可以與普通nn.Module相同的方式訪問。- apply(fn)[source]#
將
fn遞迴應用於每個子模組(由.children()返回)以及自身。典型用途包括初始化模型的引數(另請參閱 torch.nn.init)。
- 引數
fn (
Module-> None) – 要應用於每個子模組的函式- 返回
self
- 返回型別
示例
>>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Linear(in_features=2, out_features=2, bias=True) Parameter containing: tensor([[1., 1.], [1., 1.]], requires_grad=True) Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )
- buffers(recurse=True)[source]#
返回模組緩衝區的迭代器。
- 引數
recurse (bool) – 如果為 True,則會生成此模組及所有子模組的緩衝區。否則,只生成此模組的直接成員緩衝區。
- 生成
torch.Tensor – 模組緩衝區
- 返回型別
示例
>>> for buf in model.buffers(): >>> print(type(buf), buf.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- property code#
Return a pretty-printed representation (as valid Python syntax) of the internal graph for the
forwardmethod.
- property code_with_constants#
Return a tuple.
Returns a tuple of
[0] a pretty-printed representation (as valid Python syntax) of the internal graph for the
forwardmethod. See code. [1] a ConstMap following the CONSTANT.cN format of the output in [0]. The indices in the [0] output are keys to the underlying constant’s values.
- compile(*args, **kwargs)[source]#
使用
torch.compile()編譯此模組的 forward。此模組的 __call__ 方法已編譯,所有引數將按原樣傳遞給
torch.compile()。有關此函式引數的詳細資訊,請參閱
torch.compile()。
- cuda(device=None)[source]#
將所有模型引數和緩衝區移動到 GPU。
這也會使相關的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 GPU 上,則應在構建最佳化器之前呼叫此函式。
注意
此方法就地修改模組。
- eval()[source]#
將模組設定為評估模式。
這僅對某些模組有影響。有關模組在訓練/評估模式下的行為,例如它們是否受影響(如
Dropout、BatchNorm等),請參閱具體模組的文件。This is equivalent with
self.train(False).請參閱 區域性停用梯度計算,瞭解 .eval() 與一些可能與之混淆的類似機制之間的比較。
- 返回
self
- 返回型別
- get_buffer(target)[source]#
返回由
target給定的緩衝區(如果存在),否則丟擲錯誤。有關此方法功能的更詳細解釋以及如何正確指定
target,請參閱get_submodule的文件字串。- 引數
target (str) – 要查詢的緩衝區的完整限定字串名稱。(有關如何指定完整限定字串,請參閱
get_submodule。)- 返回
由
target引用的緩衝區- 返回型別
- 引發
AttributeError – 如果目標字串引用了無效路徑或解析為非緩衝區項。
- get_extra_state()[source]#
返回要包含在模組 state_dict 中的任何額外狀態。
Implement this and a corresponding
set_extra_state()for your module if you need to store extra state. This function is called when building the module’s state_dict().注意,為了保證 state_dict 的序列化工作正常,額外狀態應該是可被 pickle 的。我們僅為 Tensors 的序列化提供向後相容性保證;其他物件的序列化形式若發生變化,可能導致向後相容性中斷。
- 返回
要儲存在模組 state_dict 中的任何額外狀態
- 返回型別
- get_parameter(target)[source]#
如果存在,返回由
target給定的引數,否則丟擲錯誤。有關此方法功能的更詳細解釋以及如何正確指定
target,請參閱get_submodule的文件字串。- 引數
target (str) – 要查詢的引數的完整限定字串名稱。(有關如何指定完整限定字串,請參閱
get_submodule。)- 返回
由
target引用的引數- 返回型別
torch.nn.Parameter
- 引發
AttributeError – 如果目標字串引用了無效路徑或解析為非
nn.Parameter項。
- get_submodule(target)[source]#
如果存在,返回由
target給定的子模組,否則丟擲錯誤。例如,假設您有一個
nn.ModuleA,它看起來像這樣A( (net_b): Module( (net_c): Module( (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2)) ) (linear): Linear(in_features=100, out_features=200, bias=True) ) )(圖示了一個
nn.ModuleA。A包含一個巢狀子模組net_b,該子模組本身有兩個子模組net_c和linear。net_c隨後又有一個子模組conv。)要檢查是否存在
linear子模組,可以呼叫get_submodule("net_b.linear")。要檢查是否存在conv子模組,可以呼叫get_submodule("net_b.net_c.conv")。get_submodule的執行時複雜度受target中模組巢狀深度的限制。與named_modules的查詢相比,後者的複雜度是按傳遞模組數量計算的 O(N)。因此,對於簡單地檢查某個子模組是否存在,應始終使用get_submodule。- 引數
target (str) – 要查詢的子模組的完整限定字串名稱。(如上例所示,如何指定完整限定字串。)
- 返回
由
target引用的子模組- 返回型別
- 引發
AttributeError – 如果在
target字串解析出的路徑中的任何一點,(子)路徑解析為一個不存在的屬性名或一個非nn.Module例項的物件。
- property graph#
Return a string representation of the internal graph for the
forwardmethod.
- property inlined_graph#
Return a string representation of the internal graph for the
forwardmethod.This graph will be preprocessed to inline all function and method calls.
- ipu(device=None)[source]#
將所有模型引數和緩衝區移動到 IPU。
這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 IPU 上,則應在構建最佳化器之前呼叫它。
注意
此方法就地修改模組。
- load_state_dict(state_dict, strict=True, assign=False)[source]#
Copy parameters and buffers from
state_dictinto this module and its descendants.If
strictisTrue, then the keys ofstate_dictmust exactly match the keys returned by this module’sstate_dict()function.警告
If
assignisTruethe optimizer must be created after the call toload_state_dictunlessget_swap_module_params_on_conversion()isTrue.- 引數
state_dict (dict) – 包含引數和持久緩衝區的字典。
strict (bool, optional) – whether to strictly enforce that the keys in
state_dictmatch the keys returned by this module’sstate_dict()function. Default:Trueassign (bool, optional) – 當設定為
False時,將保留當前模組中張量的屬性;設定為True時,將保留 state dict 中張量的屬性。唯一的例外是Parameter的requires_grad欄位,此時將保留模組中的值。預設為False。
- 返回
missing_keys是一個包含此模組期望但在提供的
state_dict中缺失的任何鍵的字串列表。
unexpected_keys是一個字串列表,包含此模組不期望但在提供的
state_dict中存在的鍵。
- 返回型別
NamedTuple,包含missing_keys和unexpected_keys欄位。
注意
If a parameter or buffer is registered as
Noneand its corresponding key exists instate_dict,load_state_dict()will raise aRuntimeError.
- modules()[source]#
返回網路中所有模組的迭代器。
注意
重複的模組只返回一次。在以下示例中,
l只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.modules()): ... print(idx, '->', m) 0 -> Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) ) 1 -> Linear(in_features=2, out_features=2, bias=True)
- mtia(device=None)[source]#
將所有模型引數和緩衝區移動到 MTIA。
這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 MTIA 上,則應在構建最佳化器之前呼叫它。
注意
此方法就地修改模組。
- named_buffers(prefix='', recurse=True, remove_duplicate=True)[source]#
返回模組緩衝區上的迭代器,同時生成緩衝區的名稱和緩衝區本身。
- 引數
- 生成
(str, torch.Tensor) – 包含名稱和緩衝區的元組
- 返回型別
示例
>>> for name, buf in self.named_buffers(): >>> if name in ['running_var']: >>> print(buf.size())
- named_children()[source]#
返回對直接子模組的迭代器,生成模組的名稱和模組本身。
示例
>>> for name, module in model.named_children(): >>> if name in ['conv4', 'conv5']: >>> print(module)
- named_modules(memo=None, prefix='', remove_duplicate=True)[source]#
返回網路中所有模組的迭代器,同時生成模組的名稱和模組本身。
- 引數
- 生成
(str, Module) – 名稱和模組的元組
注意
重複的模組只返回一次。在以下示例中,
l只返回一次。示例
>>> l = nn.Linear(2, 2) >>> net = nn.Sequential(l, l) >>> for idx, m in enumerate(net.named_modules()): ... print(idx, '->', m) 0 -> ('', Sequential( (0): Linear(in_features=2, out_features=2, bias=True) (1): Linear(in_features=2, out_features=2, bias=True) )) 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
- named_parameters(prefix='', recurse=True, remove_duplicate=True)[source]#
返回模組引數的迭代器,同時生成引數的名稱和引數本身。
- 引數
- 生成
(str, Parameter) – 包含名稱和引數的元組
- 返回型別
示例
>>> for name, param in self.named_parameters(): >>> if name in ['bias']: >>> print(param.size())
- parameters(recurse=True)[source]#
返回模組引數的迭代器。
這通常傳遞給最佳化器。
- 引數
recurse (bool) – 如果為 True,則會生成此模組及所有子模組的引數。否則,只生成此模組的直接成員引數。
- 生成
Parameter – 模組引數
- 返回型別
示例
>>> for param in model.parameters(): >>> print(type(param), param.size()) <class 'torch.Tensor'> (20L,) <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
- register_backward_hook(hook)[source]#
在模組上註冊一個反向傳播鉤子。
This function is deprecated in favor of
register_full_backward_hook()and the behavior of this function will change in future versions.- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_buffer(name, tensor, persistent=True)[source]#
向模組新增一個緩衝區。
This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s
running_meanis not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by settingpersistenttoFalse. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’sstate_dict.可以使用給定名稱作為屬性訪問緩衝區。
- 引數
name (str) – 緩衝區的名稱。緩衝區可以透過給定名稱從該模組訪問。
tensor (Tensor or None) – buffer to be registered. If
None, then operations that run on buffers, such ascuda, are ignored. IfNone, the buffer is not included in the module’sstate_dict.persistent (bool) – whether the buffer is part of this module’s
state_dict.
示例
>>> self.register_buffer('running_mean', torch.zeros(num_features))
- register_forward_hook(hook, *, prepend=False, with_kwargs=False, always_call=False)[source]#
在模組上註冊一個前向鉤子。
The hook will be called every time after
forward()has computed an output.If
with_kwargsisFalseor not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to theforward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called afterforward()is called. The hook should have the following signaturehook(module, args, output) -> None or modified output
如果
with_kwargs為True,則前向鉤子將接收傳遞給 forward 函式的kwargs,並需要返回可能已修改的輸出。鉤子應該具有以下簽名hook(module, args, kwargs, output) -> None or modified output
- 引數
hook (Callable) – 使用者定義的待註冊鉤子。
prepend (bool) – If
True, the providedhookwill be fired before all existingforwardhooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingforwardhooks on thistorch.nn.Module. Note that globalforwardhooks registered withregister_module_forward_hook()will fire before all hooks registered by this method. Default:Falsewith_kwargs (bool) – 如果為
True,則hook將接收傳遞給 forward 函式的 kwargs。預設為False。always_call (bool) – 如果為
True,則無論呼叫 Module 時是否發生異常,都將執行hook。預設為False。
- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_forward_pre_hook(hook, *, prepend=False, with_kwargs=False)[source]#
在模組上註冊一個前向預鉤子。
The hook will be called every time before
forward()is invoked.如果
with_kwargs為 false 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,而只會傳遞給forward。鉤子可以修改輸入。使用者可以返回一個元組或單個修改後的值。我們將把值包裝成一個元組,如果返回的是單個值(除非該值本身就是元組)。鉤子應該具有以下簽名hook(module, args) -> None or modified input
如果
with_kwargs為 true,則前向預鉤子將接收傳遞給 forward 函式的 kwargs。如果鉤子修改了輸入,則應該返回 args 和 kwargs。鉤子應該具有以下簽名hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
- 引數
hook (Callable) – 使用者定義的待註冊鉤子。
prepend (bool) – If true, the provided
hookwill be fired before all existingforward_prehooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingforward_prehooks on thistorch.nn.Module. Note that globalforward_prehooks registered withregister_module_forward_pre_hook()will fire before all hooks registered by this method. Default:Falsewith_kwargs (bool) – 如果為 true,則
hook將接收傳遞給 forward 函式的 kwargs。預設為False。
- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_full_backward_hook(hook, prepend=False)[source]#
在模組上註冊一個反向傳播鉤子。
每次計算相對於模組的梯度時,將呼叫此鉤子,其觸發規則如下:
通常,鉤子在計算相對於模組輸入的梯度時觸發。
如果模組輸入都不需要梯度,則在計算相對於模組輸出的梯度時觸發鉤子。
如果模組輸出都不需要梯度,則鉤子將不觸發。
鉤子應具有以下簽名
hook(module, grad_input, grad_output) -> tuple(Tensor) or None
The
grad_inputandgrad_outputare tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place ofgrad_inputin subsequent computations.grad_inputwill only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries ingrad_inputandgrad_outputwill beNonefor all non-Tensor arguments.由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。
警告
使用反向傳播鉤子時不允許就地修改輸入或輸出,否則將引發錯誤。
- 引數
hook (Callable) – 要註冊的使用者定義鉤子。
prepend (bool) – If true, the provided
hookwill be fired before all existingbackwardhooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingbackwardhooks on thistorch.nn.Module. Note that globalbackwardhooks registered withregister_module_full_backward_hook()will fire before all hooks registered by this method.
- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_full_backward_pre_hook(hook, prepend=False)[source]#
在模組上註冊一個反向預鉤子。
每次計算模組的梯度時,將呼叫此鉤子。鉤子應具有以下簽名
hook(module, grad_output) -> tuple[Tensor] or None
grad_output是一個元組。鉤子不應修改其引數,但可以選擇返回一個新的輸出梯度,該梯度將取代grad_output用於後續計算。對於所有非 Tensor 引數,grad_output中的條目將為None。由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。
警告
使用反向傳播鉤子時不允許就地修改輸入,否則將引發錯誤。
- 引數
hook (Callable) – 要註冊的使用者定義鉤子。
prepend (bool) – If true, the provided
hookwill be fired before all existingbackward_prehooks on thistorch.nn.Module. Otherwise, the providedhookwill be fired after all existingbackward_prehooks on thistorch.nn.Module. Note that globalbackward_prehooks registered withregister_module_full_backward_pre_hook()will fire before all hooks registered by this method.
- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_load_state_dict_post_hook(hook)[source]#
註冊一個後鉤子,用於在模組的
load_state_dict()被呼叫後執行。- 它應該具有以下簽名:
hook(module, incompatible_keys) -> None
The
moduleargument is the current module that this hook is registered on, and theincompatible_keysargument is aNamedTupleconsisting of attributesmissing_keysandunexpected_keys.missing_keysis alistofstrcontaining the missing keys andunexpected_keysis alistofstrcontaining the unexpected keys.如果需要,可以就地修改給定的 incompatible_keys。
Note that the checks performed when calling
load_state_dict()withstrict=Trueare affected by modifications the hook makes tomissing_keysorunexpected_keys, as expected. Additions to either set of keys will result in an error being thrown whenstrict=True, and clearing out both missing and unexpected keys will avoid an error.- 返回
一個控制代碼,可用於透過呼叫
handle.remove()來移除新增的鉤子- 返回型別
torch.utils.hooks.RemovableHandle
- register_load_state_dict_pre_hook(hook)[source]#
註冊一個預鉤子,用於在模組的
load_state_dict()被呼叫之前執行。- 它應該具有以下簽名:
hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950
- 引數
hook (Callable) – 在載入狀態字典之前將呼叫的可呼叫鉤子。
- register_module(name, module)[source]#
Alias for
add_module().
- register_parameter(name, param)[source]#
向模組新增一個引數。
可以使用給定名稱作為屬性訪問該引數。
- 引數
name (str) – 引數的名稱。引數可以透過給定名稱從該模組訪問。
param (Parameter or None) – parameter to be added to the module. If
None, then operations that run on parameters, such ascuda, are ignored. IfNone, the parameter is not included in the module’sstate_dict.
- register_state_dict_post_hook(hook)[source]#
Register a post-hook for the
state_dict()method.- 它應該具有以下簽名:
hook(module, state_dict, prefix, local_metadata) -> None
註冊的鉤子可以就地修改
state_dict。
- register_state_dict_pre_hook(hook)[source]#
Register a pre-hook for the
state_dict()method.- 它應該具有以下簽名:
hook(module, prefix, keep_vars) -> None
註冊的鉤子可用於在進行
state_dict呼叫之前執行預處理。
- requires_grad_(requires_grad=True)[source]#
更改自動梯度是否應記錄此模組中引數的操作。
此方法就地設定引數的
requires_grad屬性。此方法有助於凍結模組的一部分以進行微調或單獨訓練模型的一部分(例如,GAN 訓練)。
請參閱 區域性停用梯度計算,瞭解 .requires_grad_() 與一些可能與之混淆的類似機制之間的比較。
- save(f, **kwargs)[source]#
Save with a file-like object.
save(f, _extra_files={})
See
torch.jit.savewhich accepts a file-like object. This function, torch.save(), converts the object to a string, treating it as a path. DO NOT confuse these two functions when it comes to the ‘f’ parameter functionality.
- set_extra_state(state)[source]#
設定載入的 state_dict 中包含的額外狀態。
This function is called from
load_state_dict()to handle any extra state found within the state_dict. Implement this function and a correspondingget_extra_state()for your module if you need to store extra state within its state_dict.- 引數
state (dict) – 來自 state_dict 的額外狀態。
- set_submodule(target, module, strict=False)[source]#
如果存在,設定由
target給定的子模組,否則丟擲錯誤。注意
如果
strict設定為False(預設),該方法將替換現有子模組或在父模組存在的情況下建立新子模組。如果strict設定為True,該方法將僅嘗試替換現有子模組,並在子模組不存在時引發錯誤。例如,假設您有一個
nn.ModuleA,它看起來像這樣A( (net_b): Module( (net_c): Module( (conv): Conv2d(3, 3, 3) ) (linear): Linear(3, 3) ) )(圖示了一個
nn.ModuleA。A包含一個巢狀子模組net_b,該子模組本身有兩個子模組net_c和linear。net_c隨後又有一個子模組conv。)要用一個新的
Linear子模組覆蓋Conv2d,可以呼叫set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中strict可以是True或False。要將一個新的
Conv2d子模組新增到現有的net_b模組中,可以呼叫set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))。在上面,如果設定
strict=True並呼叫set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),則會引發 AttributeError,因為net_b中不存在名為conv的子模組。- 引數
- 引發
ValueError – 如果
target字串為空,或者module不是nn.Module的例項。AttributeError – 如果在
target字串解析出的路徑中的任何一點,(子)路徑解析為一個不存在的屬性名或一個非nn.Module例項的物件。
請參閱
torch.Tensor.share_memory_()。- 返回型別
自我
- state_dict(*args, destination=None, prefix='', keep_vars=False)[source]#
返回一個字典,其中包含對模組整個狀態的引用。
引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為
None的引數和緩衝區不包含在內。注意
返回的物件是淺複製。它包含對模組引數和緩衝區的引用。
警告
當前
state_dict()還接受destination、prefix和keep_vars的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。警告
請避免使用引數
destination,因為它不是為終端使用者設計的。- 引數
- 返回
包含模組整體狀態的字典
- 返回型別
示例
>>> module.state_dict().keys() ['bias', 'weight']
- to(*args, **kwargs)[source]#
移動和/或轉換引數和緩衝區。
這可以這樣呼叫
- to(device=None, dtype=None, non_blocking=False)[source]
- to(dtype, non_blocking=False)[原始碼]
- to(tensor, non_blocking=False)[原始碼]
- to(memory_format=torch.channels_last)[原始碼]
Its signature is similar to
torch.Tensor.to(), but only accepts floating point or complexdtypes. In addition, this method will only cast the floating point or complex parameters and buffers todtype(if given). The integral parameters and buffers will be moveddevice, if that is given, but with dtypes unchanged. Whennon_blockingis set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.有關示例,請參閱下文。
注意
此方法就地修改模組。
- 引數
device (
torch.device) – 此模組中的引數和緩衝區的目標裝置dtype (
torch.dtype) – 此模組中的引數和緩衝區的目標浮點數或複數dtypetensor (torch.Tensor) – 張量,其 dtype 和裝置是此模組中所有引數和緩衝區的所需 dtype 和裝置
memory_format (
torch.memory_format) – 此模組中 4D 引數和緩衝區的目標記憶體格式(僅限關鍵字引數)
- 返回
self
- 返回型別
示例
>>> linear = nn.Linear(2, 2) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]]) >>> linear.to(torch.double) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1913, -0.3420], [-0.5113, -0.2325]], dtype=torch.float64) >>> gpu1 = torch.device("cuda:1") >>> linear.to(gpu1, dtype=torch.half, non_blocking=True) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1') >>> cpu = torch.device("cpu") >>> linear.to(cpu) Linear(in_features=2, out_features=2, bias=True) >>> linear.weight Parameter containing: tensor([[ 0.1914, -0.3420], [-0.5112, -0.2324]], dtype=torch.float16) >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble) >>> linear.weight Parameter containing: tensor([[ 0.3741+0.j, 0.2382+0.j], [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128) >>> linear(torch.ones(3, 2, dtype=torch.cdouble)) tensor([[0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j], [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
- to_empty(*, device, recurse=True)[source]#
將引數和緩衝區移動到指定裝置,而不復制儲存。
- 引數
device (
torch.device) – 此模組中的引數和緩衝區的目標裝置。recurse (bool) – 是否遞迴地將子模組的引數和緩衝區移動到指定裝置。
- 返回
self
- 返回型別
- train(mode=True)[source]#
將模組設定為訓練模式。
This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g.
Dropout,BatchNorm, etc. – 這隻對某些模組有影響。有關其在訓練/評估模式下的行為的詳細資訊,例如它們是否受影響,請參閱特定模組的文件,例如Dropout、BatchNorm等。
- xpu(device=None)[source]#
將所有模型引數和緩衝區移動到 XPU。
這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 XPU 上,則應在構建最佳化器之前呼叫它。
注意
此方法就地修改模組。
- zero_grad(set_to_none=True)[source]#
重置所有模型引數的梯度。
請參閱
torch.optim.Optimizer下的類似函式以獲取更多上下文。- 引數
set_to_none (bool) – 不設定為零,而是將梯度設定為 None。有關詳細資訊,請參閱
torch.optim.Optimizer.zero_grad()。