快捷方式

AddThinkingPrompt

class torchrl.envs.llm.transforms.AddThinkingPrompt(cond: Callable[[TensorDictBase], bool], prompt: str | None = None, random_prompt: bool = False, role: Literal['user', 'assistant'] = 'assistant', edit_last_turn: bool = True, zero_reward: bool | None = None, undo_done: bool = True, egocentric: bool | None = None)[原始碼]

A transform that adds thinking prompts to encourage the LLM to reconsider its response. (一個新增思考提示以鼓勵 LLM 重新考慮其響應的轉換。)

This transform can either add a new thinking prompt as a separate message or edit the last assistant response to include a thinking prompt before the final answer. This is useful for training LLMs to self-correct and think more carefully when their initial responses are incorrect or incomplete. (此轉換可以新增一個新的思考提示作為單獨的訊息,也可以編輯最後一個助理解釋以在最終答案之前包含思考提示。這對於訓練 LLM 在初始響應不正確或不完整時進行自我糾正和更仔細地思考非常有用。)

引數:
  • cond (Callable[[TensorDictBase], bool], optional) – Condition function that determines when to add the thinking prompt. Takes a tensordict and returns True if the prompt should be added. (條件函式,用於確定何時新增思考提示。接收一個 tensordict 並返回 True,如果應新增提示。)

  • prompt (str, optional) – The thinking prompt to add. If None, a default prompt is used. Defaults to “But wait, let me think about this more carefully…”. (要新增的思考提示。如果為 None,則使用預設提示。預設為 “等等,讓我更仔細地考慮一下……”。)

  • random_prompt (bool, optional) – Whether to randomly select from predefined prompts. Defaults to False. (是否從預定義提示中隨機選擇。預設為 False。)

  • role (Literal["user", "assistant"], optional) – The role for the thinking prompt. If “assistant”, the prompt is added to the assistant’s response. If “user”, it’s added as a separate user message. Defaults to “assistant”. (思考提示的角色。如果為 “assistant”,則提示將新增到助理解釋中。如果為 “user”,則將提示新增為單獨的使用者訊息。預設為 “assistant”。)

  • edit_last_turn (bool, optional) – Whether to edit the last assistant response instead of adding a new message. Only works with role=”assistant”. Defaults to True. (是否編輯最後一個助理解釋而不是新增新訊息。僅在 role=”assistant” 時有效。預設為 True。)

  • zero_reward (bool, optional) – Whether to zero out the reward when the thinking prompt is added. If None, defaults to the value of edit_last_turn. Defaults to the same value as edit_last_turn. (新增思考提示時是否將獎勵清零。如果為 None,則預設為 edit_last_turn 的值。預設為與 edit_last_turn 相同的值。)

  • undo_done (bool, optional) – Whether to undo the done flag when the thinking prompt is added. Defaults to True. (新增思考提示時是否撤銷 done 標誌。預設為 True。)

  • egocentric (bool, optional) – Whether the thinking prompt is written from the perspective of the assistant. Defaults to None, which means that the prompt is written from the perspective of the user if role=”user” and from the perspective of the assistant if role=”assistant”. (思考提示是否從助手的角度編寫。預設為 None,這意味著如果 role=”user”,則提示將從使用者的角度編寫;如果 role=”assistant”,則從助手的角度編寫。)

示例

>>> from torchrl.envs.llm.transforms import AddThinkingPrompt
>>> from torchrl.envs.llm import GSM8KEnv
>>> from transformers import AutoTokenizer
>>> import torch
>>>
>>> # Create environment with thinking prompt transform
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
>>> env = GSM8KEnv(tokenizer=tokenizer, max_steps=10)
>>> env = env.append_transform(
...     AddThinkingPrompt(
...         cond=lambda td: td["reward"] < 50,
...         role="assistant",
...         edit_last_turn=True,
...         zero_reward=True,
...         undo_done=True
...     )
... )
>>>
>>> # Test with wrong answer (low reward)
>>> reset = env.reset()
>>> wrong_answer = (
...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
...     "To find the total, I need to add April and May: 48 + 24 = 72. "
...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
...     "<answer>322 clips</answer><|im_end|>"
... )
>>> reset["text_response"] = [wrong_answer]
>>> s = env.step(reset)
>>> assert (s["next", "reward"] == 0).all()  # Reward zeroed
>>> assert (s["next", "done"] == 0).all()    # Done undone
>>> assert s["next", "history"].shape == (1, 3)  # History modified
>>>
>>> # Test with correct answer (high reward)
>>> reset = env.reset()
>>> correct_answer = (
...     "<think>Let me solve this step by step. Natalia sold clips to 48 friends in April. "
...     "Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
...     "To find the total, I need to add April and May: 48 + 24 = 72. "
...     "Therefore, Natalia sold 72 clips altogether in April and May.</think>"
...     "<answer>72</answer><|im_end|>"
... )
>>> reset["text_response"] = [correct_answer]
>>> s = env.step(reset)
>>> assert (s["next", "reward"] != 0).all()  # Reward not zeroed
>>> assert s["next", "done"].all()           # Done remains True
>>> assert s["next", "history"].shape == (1, 3)  # History unchanged
add_module(name: str, module: Optional[Module]) None

將子模組新增到當前模組。

可以使用給定的名稱作為屬性訪問該模組。

引數:
  • name (str) – 子模組的名稱。子模組可以透過給定名稱從此模組訪問

  • module (Module) – 要新增到模組中的子模組。

apply(fn: Callable[[Module], None]) Self

fn 遞迴應用於每個子模組(由 .children() 返回)以及自身。

典型用法包括初始化模型引數(另請參閱 torch.nn.init)。

引數:

fn (Module -> None) – 要應用於每個子模組的函式

返回:

self

返回型別:

模組

示例

>>> @torch.no_grad()
>>> def init_weights(m):
>>>     print(m)
>>>     if type(m) == nn.Linear:
>>>         m.weight.fill_(1.0)
>>>         print(m.weight)
>>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
>>> net.apply(init_weights)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
bfloat16() Self

將所有浮點引數和緩衝區轉換為 bfloat16 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

buffers(recurse: bool = True) Iterator[Tensor]

返回模組緩衝區的迭代器。

引數:

recurse (bool) – 如果為 True,則會產生此模組及其所有子模組的 buffer。否則,僅會產生此模組的直接成員 buffer。

產生:

torch.Tensor – 模組緩衝區

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for buf in model.buffers():
>>>     print(type(buf), buf.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
children() Iterator[Module]

返回直接子模組的迭代器。

產生:

Module – 子模組

close()

關閉轉換。

property collector: DataCollectorBase | None

返回與容器關聯的收集器(如果存在)。

每當變換需要了解收集器或與之關聯的策略時,都可以使用此屬性。請確保僅在未巢狀在子程序中的變換上呼叫此屬性。收集器引用不會傳遞給 ParallelEnv 或類似的批處理環境的 worker。

請確保僅在未巢狀在子程序中的轉換上呼叫此屬性。 Collector 引用不會傳遞給 ParallelEnv 或類似批次環境的 worker。

compile(*args, **kwargs)

使用 torch.compile() 編譯此 Module 的前向傳播。

此 Module 的 __call__ 方法將被編譯,並且所有引數將按原樣傳遞給 torch.compile()

有關此函式的引數的詳細資訊,請參閱 torch.compile()

property container: EnvBase | None

返回包含該變換的環境。

示例

>>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
>>> env.transform[0].container is env
True
cpu() Self

將所有模型引數和緩衝區移動到 CPU。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

cuda(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 GPU。

這也會使相關的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 GPU 上,則應在構建最佳化器之前呼叫此函式。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

double() Self

將所有浮點引數和緩衝區轉換為 double 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

eval() Self

將模組設定為評估模式。

這僅對某些模組有影響。有關模組在訓練/評估模式下的行為,例如它們是否受影響(如 DropoutBatchNorm 等),請參閱具體模組的文件。

這等同於 self.train(False)

有關 .eval() 和幾種可能與之混淆的類似機制之間的比較,請參閱 區域性停用梯度計算

返回:

self

返回型別:

模組

extra_repr() str

返回模組的額外表示。

要列印自定義額外資訊,您應該在自己的模組中重新實現此方法。單行和多行字串均可接受。

float() Self

將所有浮點引數和緩衝區轉換為 float 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

forward(tensordict: TensorDictBase = None) TensorDictBase

讀取輸入 tensordict,並對選定的鍵應用轉換。

預設情況下,此方法

  • 直接呼叫 _apply_transform()

  • 不呼叫 _step()_call()

此方法不會在任何時候在 env.step 中呼叫。但是,它會在 sample() 中呼叫。

注意

forward 也可以使用 dispatch 將引數名稱轉換為鍵,並使用常規關鍵字引數。

示例

>>> class TransformThatMeasuresBytes(Transform):
...     '''Measures the number of bytes in the tensordict, and writes it under `"bytes"`.'''
...     def __init__(self):
...         super().__init__(in_keys=[], out_keys=["bytes"])
...
...     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
...         bytes_in_td = tensordict.bytes()
...         tensordict["bytes"] = bytes
...         return tensordict
>>> t = TransformThatMeasuresBytes()
>>> env = env.append_transform(t) # works within envs
>>> t(TensorDict(a=0))  # Works offline too.
get_buffer(target: str) Tensor

返回由 target 給定的緩衝區(如果存在),否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數:

target – 要查詢的 buffer 的完全限定字串名稱。(要指定完全限定字串,請參閱 get_submodule。)

返回:

target 引用的緩衝區

返回型別:

torch.Tensor

丟擲:

AttributeError – 如果目標字串引用了無效路徑或解析為非 buffer 物件。

get_extra_state() Any

返回要包含在模組 state_dict 中的任何額外狀態。

Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict(). (如果需要儲存額外狀態,請實現此函式和相應的 set_extra_state() 函式。在構建模組的 state_dict() 時會呼叫此函式。)

注意,為了保證 state_dict 的序列化工作正常,額外狀態應該是可被 pickle 的。我們僅為 Tensors 的序列化提供向後相容性保證;其他物件的序列化形式若發生變化,可能導致向後相容性中斷。

返回:

要儲存在模組 state_dict 中的任何額外狀態

返回型別:

物件

get_parameter(target: str) Parameter

如果存在,返回由 target 給定的引數,否則丟擲錯誤。

有關此方法功能的更詳細解釋以及如何正確指定 target,請參閱 get_submodule 的文件字串。

引數:

target – 要查詢的 Parameter 的完全限定字串名稱。(要指定完全限定字串,請參閱 get_submodule。)

返回:

target 引用的引數

返回型別:

torch.nn.Parameter

丟擲:

AttributeError – 如果目標字串引用了無效路徑或解析為非 nn.Parameter 的物件。

get_submodule(target: str) Module

如果存在,返回由 target 給定的子模組,否則丟擲錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
        )
        (linear): Linear(in_features=100, out_features=200, bias=True)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要檢查是否存在 linear 子模組,可以呼叫 get_submodule("net_b.linear")。要檢查是否存在 conv 子模組,可以呼叫 get_submodule("net_b.net_c.conv")

get_submodule 的執行時複雜度受 target 中模組巢狀深度的限制。與 named_modules 的查詢相比,後者的複雜度是按傳遞模組數量計算的 O(N)。因此,對於簡單地檢查某個子模組是否存在,應始終使用 get_submodule

引數:

target – 要查詢的子模組的完全限定字串名稱。(要指定完全限定字串,請參閱上面的示例。)

返回:

target 引用的子模組

返回型別:

torch.nn.Module

丟擲:

AttributeError – 如果在目標字串解析的任何路徑中,子路徑解析為不存在的屬性名或不是 nn.Module 例項的物件。

half() Self

將所有浮點引數和緩衝區轉換為 half 資料型別。

注意

此方法就地修改模組。

返回:

self

返回型別:

模組

init(tensordict) None

執行轉換的初始化步驟。

inv(tensordict: TensorDictBase = None) TensorDictBase

讀取輸入 tensordict,並對選定的鍵應用逆變換。

預設情況下,此方法

  • 直接呼叫 _inv_apply_transform()

  • 不呼叫 _inv_call()

注意

inv 也透過使用 dispatch 將引數名稱強制轉換為鍵來處理常規關鍵字引數。

注意

invextend() 呼叫。

ipu(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 IPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 IPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False)

Copy parameters and buffers from state_dict into this module and its descendants. (從 state_dict 複製引數和緩衝區到此模組及其子模組中。)

If strict is True, then the keys of state_dict must match the keys returned by this module’s state_dict() function exactly. (如果 strictTrue,則 state_dict 的鍵必須與此模組的 state_dict() 函式返回的鍵完全匹配。)

警告

If assign is True the optimizer must be created after the call to load_state_dict unless get_swap_module_params_on_conversion() is True. (如果 assignTrue,則必須在呼叫 load_state_dict 之後建立最佳化器,除非 get_swap_module_params_on_conversion()True。)

引數:
  • state_dict (dict) – 包含引數和持久 buffer 的字典。

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True (是否嚴格強制 state_dict 中的鍵與此模組的 state_dict() 函式返回的鍵匹配。預設值:True。)

  • assign (bool, optional) – 當設定為 False 時,將保留當前模組中張量的屬性;當設定為 True 時,將保留 state_dict 中張量的屬性。唯一的例外是 Parameterrequires_grad 欄位,此時將保留模組的值。預設值:False

返回:

  • missing_keys 是一個包含此模組期望但

    在提供的 state_dict 中缺失的任何鍵的字串列表。

  • unexpected_keys 是一個字串列表,包含此模組

    不期望但在提供的 state_dict 中存在的鍵。

返回型別:

NamedTuple,包含 missing_keysunexpected_keys 欄位。

注意

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError. (如果引數或緩衝區註冊為 None 並且其對應的鍵存在於 state_dict 中,load_state_dict() 將引發 RuntimeError。)

modules() Iterator[Module]

返回網路中所有模組的迭代器。

產生:

Module – 網路中的一個模組

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.modules()):
...     print(idx, '->', m)

0 -> Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
)
1 -> Linear(in_features=2, out_features=2, bias=True)
mtia(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 MTIA。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 MTIA 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.Tensor]]

返回模組緩衝區上的迭代器,同時生成緩衝區的名稱和緩衝區本身。

引數:
  • prefix (str) – 為所有 buffer 名稱新增字首。

  • recurse (bool, optional) – 如果為 True,則會生成此模組及其所有子模組的 buffers。否則,僅生成此模組直接成員的 buffers。預設為 True。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的 buffers。預設為 True。

產生:

(str, torch.Tensor) – 包含名稱和緩衝區的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_children() Iterator[tuple[str, 'Module']]

返回對直接子模組的迭代器,生成模組的名稱和模組本身。

產生:

(str, Module) – 包含名稱和子模組的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, module in model.named_children():
>>>     if name in ['conv4', 'conv5']:
>>>         print(module)
named_modules(memo: Optional[set['Module']] = None, prefix: str = '', remove_duplicate: bool = True)

返回網路中所有模組的迭代器,同時生成模組的名稱和模組本身。

引數:
  • memo – 用於儲存已新增到結果中的模組集合的 memo

  • prefix – 將新增到模組名稱的名稱字首

  • remove_duplicate – 是否從結果中刪除重複的模組例項

產生:

(str, Module) – 名稱和模組的元組

注意

重複的模組只返回一次。在以下示例中,l 只返回一次。

示例

>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
...     print(idx, '->', m)

0 -> ('', Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[tuple[str, torch.nn.parameter.Parameter]]

返回模組引數的迭代器,同時生成引數的名稱和引數本身。

引數:
  • prefix (str) – 為所有引數名稱新增字首。

  • recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

  • remove_duplicate (bool, optional) – 是否在結果中刪除重複的引數。預設為 True。

產生:

(str, Parameter) – 包含名稱和引數的元組

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
parameters(recurse: bool = True) Iterator[Parameter]

返回模組引數的迭代器。

這通常傳遞給最佳化器。

引數:

recurse (bool) – 如果為 True,則會生成此模組及其所有子模組的引數。否則,僅生成此模組直接成員的引數。

產生:

Parameter – 模組引數

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> for param in model.parameters():
>>>     print(type(param), param.size())
<class 'torch.Tensor'> (20L,)
<class 'torch.Tensor'> (20L, 1L, 5L, 5L)
property parent: TransformedEnv | None

返回變換的父環境。

父環境是包含直到當前變換的所有變換的環境。

示例

>>> from torchrl.envs import TransformedEnv, Compose, RewardSum, StepCounter
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"), Compose(RewardSum(), StepCounter()))
>>> env.transform[1].parent
TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            RewardSum(keys=['reward'])))
register_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]]) RemovableHandle

在模組上註冊一個反向傳播鉤子。

此函式已棄用,建議使用 register_full_backward_hook(),並且此函式在未來版本中的行為將發生變化。

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) None

向模組新增一個緩衝區。

This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict. (這通常用於註冊不應被視為模型引數的緩衝區。例如,BatchNorm 的 running_mean 不是引數,但它是模組狀態的一部分。預設情況下,緩衝區是持久的,並且將與引數一起儲存。透過將 persistent 設定為 False 可以更改此行為。持久緩衝區和非持久緩衝區之間的唯一區別是後者將不包含在此模組的 state_dict 中。)

可以使用給定名稱作為屬性訪問緩衝區。

引數:
  • name (str) – buffer 的名稱。可以使用給定的名稱從此模組訪問 buffer

  • tensor (Tensor or None) – buffer to be registered. If None, then operations that run on buffers, such as cuda, are ignored. If None, the buffer is not included in the module’s state_dict. (要註冊的緩衝區。如果為 None,則在緩衝區上執行的操作,例如 cuda,將被忽略。如果為 None,則該緩衝區包含在模組的 state_dict 中。)

  • persistent (bool) – whether the buffer is part of this module’s state_dict. (緩衝區是否是此模組 state_dict 的一部分。)

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> self.register_buffer('running_mean', torch.zeros(num_features))
register_forward_hook(hook: Union[Callable[[T, tuple[Any, ...]], Any], Callable[[T, tuple[Any, ...], dict[str, Any]], Any]], *, prepend: bool = False, with_kwargs: bool = False, always_call: bool = False) RemovableHandle

在模組上註冊一個前向鉤子。

The hook will be called every time after forward() has computed an output. (每次在 forward() 計算完輸出後都會呼叫此鉤子。)

If with_kwargs is False or not specified, the input contains only the positional arguments given to the module. Keyword arguments won’t be passed to the hooks and only to the forward. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after forward() is called. The hook should have the following signature (如果 with_kwargsFalse 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,只會傳遞給 forward。鉤子可以修改輸出。它可以原地修改輸入,但這不會影響 forward,因為此操作是在呼叫 forward() 之後呼叫的。鉤子應具有以下簽名)

hook(module, args, output) -> None or modified output

如果 with_kwargsTrue,則前向鉤子將接收傳遞給 forward 函式的 kwargs,並需要返回可能已修改的輸出。鉤子應該具有以下簽名

hook(module, args, kwargs, output) -> None or modified output
引數:
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – If True, the provided hook will be fired before all existing forward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward hooks on this torch.nn.Module. Note that global forward hooks registered with register_module_forward_hook() will fire before all hooks registered by this method. Default: False (如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 forward 鉤子之前觸發。否則,提供的 hook 將在此 torch.nn.Module 的所有現有 forward 鉤子之後觸發。請注意,使用 register_module_forward_hook() 註冊的全域性 forward 鉤子將在透過此方法註冊的所有鉤子之前觸發。預設值:False。)

  • with_kwargs (bool) – 如果為 True,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

  • always_call (bool) – 如果為 True,則無論在呼叫 Module 時是否引發異常,都會執行 hook。預設為 False

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_forward_pre_hook(hook: Union[Callable[[T, tuple[Any, ...]], Optional[Any]], Callable[[T, tuple[Any, ...], dict[str, Any]], Optional[tuple[Any, dict[str, Any]]]]], *, prepend: bool = False, with_kwargs: bool = False) RemovableHandle

在模組上註冊一個前向預鉤子。

The hook will be called every time before forward() is invoked. (每次在呼叫 forward() 之前都會呼叫此鉤子。)

如果 with_kwargs 為 false 或未指定,則輸入僅包含傳遞給模組的位置引數。關鍵字引數不會傳遞給鉤子,而只會傳遞給 forward。鉤子可以修改輸入。使用者可以返回一個元組或單個修改後的值。我們將把值包裝成一個元組,如果返回的是單個值(除非該值本身就是元組)。鉤子應該具有以下簽名

hook(module, args) -> None or modified input

如果 with_kwargs 為 true,則前向預鉤子將接收傳遞給 forward 函式的 kwargs。如果鉤子修改了輸入,則應該返回 args 和 kwargs。鉤子應該具有以下簽名

hook(module, args, kwargs) -> None or a tuple of modified input and kwargs
引數:
  • hook (Callable) – 使用者定義的待註冊鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing forward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing forward_pre hooks on this torch.nn.Module. Note that global forward_pre hooks registered with register_module_forward_pre_hook() will fire before all hooks registered by this method. Default: False (如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 forward_pre 鉤子之前觸發。否則,提供的 hook 將在此 torch.nn.Module 的所有現有 forward_pre 鉤子之後觸發。請注意,使用 register_module_forward_pre_hook() 註冊的全域性 forward_pre 鉤子將在透過此方法註冊的所有鉤子之前觸發。預設值:False。)

  • with_kwargs (bool) – 如果為 True,則 hook 將接收傳遞給 forward 函式的 kwargs。預設為 False

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_full_backward_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor], Union[tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模組上註冊一個反向傳播鉤子。

每次計算相對於模組的梯度時,將呼叫此鉤子,其觸發規則如下:

  1. 通常,鉤子在計算相對於模組輸入的梯度時觸發。

  2. 如果模組輸入都不需要梯度,則在計算相對於模組輸出的梯度時觸發鉤子。

  3. 如果模組輸出都不需要梯度,則鉤子將不觸發。

鉤子應具有以下簽名

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

The grad_input and grad_output are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of grad_input in subsequent computations. grad_input will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in grad_input and grad_output will be None for all non-Tensor arguments. ( grad_inputgrad_output 是包含相對於輸入和輸出的梯度的元組。鉤子不應修改其引數,但它可以選擇返回一個相對於輸入的新的梯度,該梯度將在後續計算中替代 grad_inputgrad_input 將僅對應於作為位置引數給出的輸入,並且所有關鍵字引數都將被忽略。對於所有非 Tensor 引數,grad_inputgrad_output 中的條目將為 None。)

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入或輸出,否則將引發錯誤。

引數:
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward hooks on this torch.nn.Module. Note that global backward hooks registered with register_module_full_backward_hook() will fire before all hooks registered by this method. (如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 backward 鉤子之前觸發。否則,提供的 hook 將在此 torch.nn.Module 的所有現有 backward 鉤子之後觸發。請注意,使用 register_module_full_backward_hook() 註冊的全域性 backward 鉤子將在透過此方法註冊的所有鉤子之前觸發。)

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_full_backward_pre_hook(hook: Callable[[Module, Union[tuple[torch.Tensor, ...], Tensor]], Union[None, tuple[torch.Tensor, ...], Tensor]], prepend: bool = False) RemovableHandle

在模組上註冊一個反向預鉤子。

每次計算模組的梯度時,將呼叫此鉤子。鉤子應具有以下簽名

hook(module, grad_output) -> tuple[Tensor] or None

grad_output 是一個元組。鉤子不應修改其引數,但可以選擇返回一個新的輸出梯度,該梯度將取代 grad_output 用於後續計算。對於所有非 Tensor 引數,grad_output 中的條目將為 None

由於技術原因,當此鉤子應用於模組時,其前向函式將接收傳遞給模組的每個張量的檢視。類似地,呼叫者將接收模組前向函式返回的每個張量的檢視。

警告

使用反向傳播鉤子時不允許就地修改輸入,否則將引發錯誤。

引數:
  • hook (Callable) – 要註冊的使用者定義鉤子。

  • prepend (bool) – If true, the provided hook will be fired before all existing backward_pre hooks on this torch.nn.Module. Otherwise, the provided hook will be fired after all existing backward_pre hooks on this torch.nn.Module. Note that global backward_pre hooks registered with register_module_full_backward_pre_hook() will fire before all hooks registered by this method. (如果為 True,則提供的 hook 將在對此 torch.nn.Module 的所有現有 backward_pre 鉤子之前觸發。否則,提供的 hook 將在此 torch.nn.Module 的所有現有 backward_pre 鉤子之後觸發。請注意,使用 register_module_full_backward_pre_hook() 註冊的全域性 backward_pre 鉤子將在透過此方法註冊的所有鉤子之前觸發。)

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_load_state_dict_post_hook(hook)

註冊一個後鉤子,用於在模組的 load_state_dict() 被呼叫後執行。

它應該具有以下簽名:

hook(module, incompatible_keys) -> None

The module argument is the current module that this hook is registered on, and the incompatible_keys argument is a NamedTuple consisting of attributes missing_keys and unexpected_keys. missing_keys is a list of str containing the missing keys and unexpected_keys is a list of str containing the unexpected keys. ( module 引數是當前註冊此鉤子的模組,而 incompatible_keys 引數是包含 missing_keysunexpected_keys 屬性的 NamedTuplemissing_keys 是一個包含缺失鍵的 str 列表,而 unexpected_keys 是一個包含意外部索引鍵的 str 列表。)

如果需要,可以就地修改給定的 incompatible_keys。

Note that the checks performed when calling load_state_dict() with strict=True are affected by modifications the hook makes to missing_keys or unexpected_keys, as expected. Additions to either set of keys will result in an error being thrown when strict=True, and clearing out both missing and unexpected keys will avoid an error. (請注意,當以 strict=True 呼叫 load_state_dict() 時執行的檢查會受到鉤子對 missing_keysunexpected_keys 的修改的影響,正如預期的那樣。向任一鍵集新增內容將導致在 strict=True 時丟擲錯誤,而清空缺失和意外部索引鍵將避免錯誤。)

返回:

一個控制代碼,可用於透過呼叫 handle.remove() 來移除新增的鉤子

返回型別:

torch.utils.hooks.RemovableHandle

register_load_state_dict_pre_hook(hook)

註冊一個預鉤子,用於在模組的 load_state_dict() 被呼叫之前執行。

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -> None # noqa: B950

引數:

hook (Callable) – 在載入狀態字典之前將呼叫的可呼叫鉤子。

register_module(name: str, module: Optional[Module]) None

Alias for add_module(). ( add_module() 的別名。)

register_parameter(name: str, param: Optional[Parameter]) None

向模組新增一個引數。

可以使用給定名稱作為屬性訪問該引數。

引數:
  • name (str) – 引數的名稱。可以透過給定名稱從該模組訪問該引數。

  • param (Parameter or None) – parameter to be added to the module. If None, then operations that run on parameters, such as cuda, are ignored. If None, the parameter is not included in the module’s state_dict. (要新增到模組的引數。如果為 None,則在引數上執行的操作,例如 cuda,將被忽略。如果為 None,則該引數包含在模組的 state_dict 中。)

register_state_dict_post_hook(hook)

註冊 state_dict() 方法的後置鉤子。

它應該具有以下簽名:

hook(module, state_dict, prefix, local_metadata) -> None

註冊的鉤子可以就地修改 state_dict

register_state_dict_pre_hook(hook)

註冊 state_dict() 方法的前置鉤子。

它應該具有以下簽名:

hook(module, prefix, keep_vars) -> None

註冊的鉤子可用於在進行 state_dict 呼叫之前執行預處理。

requires_grad_(requires_grad: bool = True) Self

更改自動梯度是否應記錄此模組中引數的操作。

此方法就地設定引數的 requires_grad 屬性。

此方法有助於凍結模組的一部分以進行微調或單獨訓練模型的一部分(例如,GAN 訓練)。

請參閱 本地停用梯度計算 以比較 .requires_grad_() 和幾種可能與之混淆的類似機制。

引數:

requires_grad (bool) – 自動求導是否應記錄此模組上的引數操作。預設為 True

返回:

self

返回型別:

模組

set_extra_state(state: Any) None

設定載入的 state_dict 中包含的額外狀態。

This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict. (此函式由 load_state_dict() 呼叫,用於處理 state_dict 中的任何額外狀態。如果需要在此模組的 state_dict 中儲存額外狀態,請實現此函式和相應的 get_extra_state() 函式。)

引數:

state (dict) – 來自 state_dict 的額外狀態

set_submodule(target: str, module: Module, strict: bool = False) None

如果存在,設定由 target 給定的子模組,否則丟擲錯誤。

注意

如果 strict 設定為 False(預設),該方法將替換現有子模組或在父模組存在的情況下建立新子模組。如果 strict 設定為 True,該方法將僅嘗試替換現有子模組,並在子模組不存在時引發錯誤。

例如,假設您有一個 nn.Module A,它看起來像這樣

A(
    (net_b): Module(
        (net_c): Module(
            (conv): Conv2d(3, 3, 3)
        )
        (linear): Linear(3, 3)
    )
)

(圖示了一個 nn.Module AA 包含一個巢狀子模組 net_b,該子模組本身有兩個子模組 net_clinearnet_c 隨後又有一個子模組 conv。)

要用一個新的 Linear 子模組覆蓋 Conv2d,可以呼叫 set_submodule("net_b.net_c.conv", nn.Linear(1, 1)),其中 strict 可以是 TrueFalse

要將一個新的 Conv2d 子模組新增到現有的 net_b 模組中,可以呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))

在上面,如果設定 strict=True 並呼叫 set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True),則會引發 AttributeError,因為 net_b 中不存在名為 conv 的子模組。

引數:
  • target – 要查詢的子模組的完全限定字串名稱。(要指定完全限定字串,請參閱上面的示例。)

  • module – 要設定子模組的物件。

  • strict – 如果為 False,該方法將替換現有子模組或建立新子模組(如果父模組存在)。如果為 True,則該方法只會嘗試替換現有子模組,如果子模組不存在則丟擲錯誤。

丟擲:
  • ValueError – 如果 target 字串為空或 module 不是 nn.Module 的例項。

  • AttributeError – 如果 target 字串路徑中的任何一點解析為一個不存在的屬性名或不是 nn.Module 例項的物件。

share_memory() Self

請參閱 torch.Tensor.share_memory_()

state_dict(*args, destination=None, prefix='', keep_vars=False)

返回一個字典,其中包含對模組整個狀態的引用。

引數和持久緩衝區(例如,執行平均值)都包含在內。鍵是相應的引數和緩衝區名稱。設定為 None 的引數和緩衝區不包含在內。

注意

返回的物件是淺複製。它包含對模組引數和緩衝區的引用。

警告

當前 state_dict() 還接受 destinationprefixkeep_vars 的位置引數,順序為。但是,這正在被棄用,並且在未來的版本中將強制使用關鍵字引數。

警告

請避免使用引數 destination,因為它不是為終端使用者設計的。

引數:
  • destination (dict, optional) – 如果提供,模組的狀態將更新到 dict 中,並返回相同的物件。否則,將建立一個 OrderedDict 並返回。預設為 None

  • prefix (str, optional) – a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''

  • keep_vars (bool, optional) – 預設情況下,state dict 中返回的 Tensors 會從 autograd 中分離。如果設定為 True,則不會執行分離。預設為 False

返回:

包含模組整體狀態的字典

返回型別:

dict

示例

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
to(*args, **kwargs) Transform

移動和/或轉換引數和緩衝區。

這可以這樣呼叫

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. (其簽名與 torch.Tensor.to() 類似,但只接受浮點或複數 dtype。此外,此方法只會將浮點或複數引數和緩衝區轉換為(如果給定)dtype。如果給定了 device,整數引數和緩衝區將被移動到 device,但 dtype 不變。當設定 non_blocking 時,它會嘗試儘可能非同步地(相對於主機)轉換/移動,例如,將具有固定記憶體的 CPU Tensor 移動到 CUDA 裝置。)

有關示例,請參閱下文。

注意

此方法就地修改模組。

引數:
  • device (torch.device) – the desired device of the parameters and buffers in this module – 此模組中引數和緩衝區的目標裝置。

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module – 此模組中引數和緩衝區的目標浮點數或複數 dtype。

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module – 其 dtype 和 device 是此模組中所有引數和緩衝區的目標 dtype 和 device 的 Tensor。

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument) – 此模組中 4D 引數和緩衝區的目標記憶體格式(僅關鍵字引數)。

返回:

self

返回型別:

模組

示例

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
to_empty(*, device: Optional[Union[int, str, device]], recurse: bool = True) Self

將引數和緩衝區移動到指定裝置,而不復制儲存。

引數:
  • device (torch.device) – The desired device of the parameters and buffers in this module. – 此模組中引數和緩衝區的目標裝置。

  • recurse (bool) – 是否遞迴地將子模組的引數和緩衝區移動到指定裝置。

返回:

self

返回型別:

模組

train(mode: bool = True) Self

將模組設定為訓練模式。

This has an effect only on certain modules. See the documentation of particular modules for details of their behaviors in training/evaluation mode, i.e., whether they are affected, e.g. Dropout, BatchNorm, etc. – 這隻對某些模組有影響。有關其在訓練/評估模式下的行為的詳細資訊,例如它們是否受影響,請參閱特定模組的文件,例如 DropoutBatchNorm 等。

引數:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True. – 設定訓練模式(True)或評估模式(False)。預設值:True

返回:

self

返回型別:

模組

transform_action_spec(action_spec: TensorSpec) TensorSpec

轉換動作規範,使結果規範與變換對映匹配。

引數:

action_spec (TensorSpec) – 變換前的規範

返回:

轉換後的預期規範

transform_done_spec(done_spec: TensorSpec) TensorSpec

變換 done spec,使結果 spec 與變換對映匹配。

引數:

done_spec (TensorSpec) – 變換前的 spec

返回:

轉換後的預期規範

transform_env_batch_size(batch_size: Size) Size

轉換父環境的 batch-size。

transform_env_device(device: device) device

轉換父環境的 device。

transform_input_spec(input_spec: TensorSpec) TensorSpec

轉換輸入規範,使結果規範與轉換對映匹配。

引數:

input_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_observation_spec(observation_spec: TensorSpec) TensorSpec

轉換觀察規範,使結果規範與轉換對映匹配。

引數:

observation_spec (TensorSpec) – 轉換前的規範

返回:

轉換後的預期規範

transform_output_spec(output_spec: Composite) Composite

轉換輸出規範,使結果規範與轉換對映匹配。

This method should generally be left untouched. Changes should be implemented using transform_observation_spec(), transform_reward_spec() and transform_full_done_spec(). :param output_spec: spec before the transform :type output_spec: TensorSpec (此方法通常應保持不變。更改應使用 transform_observation_spec()transform_reward_spec()transform_full_done_spec() 實現。 :param output_spec: 轉換前的 spec :type output_spec: TensorSpec)

返回:

轉換後的預期規範

transform_reward_spec(reward_spec: TensorSpec) TensorSpec

轉換獎勵的 spec,使其與變換對映匹配。

引數:

reward_spec (TensorSpec) – 變換前的 spec

返回:

轉換後的預期規範

transform_state_spec(state_spec: TensorSpec) TensorSpec

轉換狀態規範,使結果規範與變換對映匹配。

引數:

state_spec (TensorSpec) – 變換前的規範

返回:

轉換後的預期規範

type(dst_type: Union[dtype, str]) Self

將所有引數和緩衝區轉換為 dst_type

注意

此方法就地修改模組。

引數:

dst_type (type or string) – 目標型別

返回:

self

返回型別:

模組

xpu(device: Optional[Union[device, int]] = None) Self

將所有模型引數和緩衝區移動到 XPU。

這也會使關聯的引數和緩衝區成為不同的物件。因此,如果模組在最佳化時將駐留在 XPU 上,則應在構建最佳化器之前呼叫它。

注意

此方法就地修改模組。

引數:

device (int, optional) – 如果指定,所有引數將複製到該裝置

返回:

self

返回型別:

模組

zero_grad(set_to_none: bool = True) None

重置所有模型引數的梯度。

See similar function under torch.optim.Optimizer for more context. – 有關更多背景資訊,請參閱 torch.optim.Optimizer 下的類似函式。

引數:

set_to_none (bool) – instead of setting to zero, set the grads to None. See torch.optim.Optimizer.zero_grad() for details. – 與其設定為零,不如將 grad 設定為 None。有關詳細資訊,請參閱 torch.optim.Optimizer.zero_grad()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源