快捷方式

LLM 介面

TorchRL 為 LLM 的訓練後和微調提供了全面的框架。LLM API 基於五個核心概念構建,這些概念共同作用,為語言模型建立完整的強化學習管道。

  1. 資料表示 (資料結構):用於處理對話、文字解析和 LLM 輸出類。這包括用於管理對話上下文的 History 類,以及用於 token、對數機率和文字的結構化輸出類。

  2. LLM 包裝器 API (模組):用於不同 LLM 後端的統一介面,包括用於 Hugging Face 模型的 TransformersWrapper,用於 vLLM 推理的 vLLMWrapper,以及用於高效能分散式 vLLM 推理(推薦)的 AsyncVLLM。這些包裝器在不同後端之間提供了一致的輸入/輸出格式,併為損失計算、資料儲存、評分、權重同步等提供了一個整合的介面。

  3. 環境 (環境):負責資料載入、工具執行、獎勵計算和格式化的編排層。這包括用於對話管理的 ChatEnv、資料集環境以及用於工具整合的各種轉換。

  4. 目標 (目標):用於 LLM 訓練的專用損失函式,包括用於組相對策略最佳化(Group Relative Policy Optimization)的 GRPOLoss 和用於監督微調(supervised fine-tuning)的 SFTLoss

  5. 收集器 (收集器):收集器用於從環境中收集資料並將其儲存為可用於訓練的格式。這包括用於從環境中收集資料的 LLMCollector 和使用 Ray 在分散式環境中收集資料的 RayLLMCollector

這些元件協同工作,構建一個完整的管道:環境負責載入和格式化資料,LLM 包裝器負責推理,資料結構負責維護對話上下文,目標負責計算訓練損失。模組化設計允許您根據具體用例混合搭配元件。

sota-implementations/grpo/ 目錄中可以找到一個使用 LLM API 的完整示例。訓練編排涉及三個主要元件:

  • 資料收集器:持有對環境和推理模型或引擎的引用。它收集資料,放入緩衝區,並處理權重更新。

  • 回放緩衝區:儲存收集到的資料並執行任何預處理或後處理步驟。這些可能包括:- 使用基於蒙特卡羅的方法進行優勢估計(使用 MCAdvantage 轉換);- 對輸出進行評分;- 日誌記錄等。

  • 訓練器:處理訓練迴圈,包括最佳化步驟、序列化、日誌記錄和權重更新初始化。

警告

LLM API 仍在開發中,未來可能會發生變化。歡迎提供反饋、報告問題和提交 PR!

資料結構

資料表示層為以結構化的方式處理對話和 LLM 輸出奠定了基礎。

History 類

與 Transformers 中通常存在的聊天格式(請參閱 Hugging Face 聊天文件)相比,History 類是 TensorClass 的版本。它提供了一個全面的 API 來管理對話資料,功能包括:

  • 文字解析和格式化:使用 from_text()apply_chat_template() 在文字和結構化對話格式之間進行轉換。

  • 動態對話構建:使用 append()extend() 方法追加和擴充套件對話。

  • 多模型支援:自動檢測不同模型系列(Qwen、DialoGPT、Falcon、DeepSeek 等)的模板。

  • 助手 token 遮蔽:識別為強化學習應用而由助手生成的 token。

  • 工具呼叫支援:在對話中處理函式呼叫和工具響應。

  • 批處理操作:用於同時處理多個對話的高效張量操作。

History(role, content[, is_complete, ...])

ContentBase(type, text, url, data, ...[, ...])

支援的模型系列

我們目前支援以下模型系列進行字串到 History 的解析或助手 token 遮蔽:

  • Qwen 系列(例如,Qwen/Qwen2.5-0.5B):具有完整工具呼叫支援的自定義模板。

  • DialoGPT 系列(例如,microsoft/DialoGPT-medium):用於對話格式的自定義模板。

  • Falcon 系列(例如,tiiuae/falcon-7b-instruct):用於指令格式的自定義模板。

  • DeepSeek 系列(例如,deepseek-ai/deepseek-coder-6.7b-base):具有原生格式的自定義模板。

其他模型也得到支援,但您需要為它們提供自定義模板。LLAMA、Mistral、OPT、GPT、MPT、BLOOM、Pythia、Phi 等將使用預設的 chatml_format 模板。

用法

>>> from torchrl.data.llm.chat import History
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"},
...     {"role": "user", "content": "How are you?"},
...     {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
...     chat_template_name="qwen",
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")

新增自定義模板

您可以使用 torchrl.data.llm.chat.add_chat_template() 函式為新模型系列新增自定義聊天模板。

用法示例

新增 Llama 模板
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Define the Llama chat template
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> # Define the inverse parser for Llama format
>>> def parse_llama_text(text: str) -> History:
...     import re
...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
...     matches = re.findall(pattern, text, re.DOTALL)
...     messages = []
...     for user_content, assistant_content in matches:
...         messages.append(History(role="user", content=user_content.strip()))
...         messages.append(History(role="assistant", content=assistant_content.strip()))
...     return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
...     template_name="llama",
...     template=llama_template,
...     inverse_parser=parse_llama_text,
...     model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )

測試您的自定義模板

在新增自定義模板時,您應該對其進行測試以確保其正常工作。以下是建議的測試:

助手 token 遮蔽測試

測試您的模板是否支援助手 token 遮蔽。

import pytest
from torchrl.data.llm.chat import History, add_chat_template
from transformers import AutoTokenizer

def test_my_model_assistant_masking():
    """Test that your model supports assistant token masking."""
    # Add your template first
    add_chat_template(
        template_name="my_model",
        template="your_template_here",
        model_family_keywords=["my_model"]
    )

    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    result = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )

    # Verify assistant mask is present
    assert 'assistant_masks' in result
    assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
    assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"

    # Verify some assistant tokens are masked
    assistant_token_count = result['assistant_masks'].sum().item()
    assert assistant_token_count > 0, "Should have assistant tokens masked"
    print(f"✓ {assistant_token_count} assistant tokens masked")
模板等價性測試

測試您的自定義模板是否產生與模型預設模板相同的輸出(不包括遮蔽)。

def test_my_model_template_equivalence():
    """Test that your template matches the model's default template."""
    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'},
        {'role': 'user', 'content': 'How are you?'},
        {'role': 'assistant', 'content': 'I\'m good, thanks!'},
    ]])

    # Get output with model's default template
    try:
        default_out = history.apply_chat_template(
            tokenizer=tokenizer,
            add_generation_prompt=False,
            chat_template=tokenizer.chat_template,
            tokenize=False,
        )
    except Exception as e:
        default_out = None
        print(f"[WARN] Could not get default template: {e}")

    # Get output with your custom template
    custom_out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        chat_template_name="my_model",
        tokenize=False,
    )

    if default_out is not None:
        # Normalize whitespace for comparison
        import re
        def norm(s):
            return re.sub(r"\s+", " ", s.strip())

        assert norm(default_out) == norm(custom_out), (
            f"Custom template does not match default!\n"
            f"Default: {default_out}\nCustom: {custom_out}"
        )
        print("✓ Template equivalence verified")
    else:
        print("[INFO] Skipped equivalence check (no default template available)")
反向解析測試

如果您提供了反向解析器,請測試其是否正常工作。

def test_my_model_inverse_parsing():
    """Test that your inverse parser works correctly."""
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    # Format using your template
    formatted = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        tokenize=False,
    )

    # Parse back using your inverse parser
    parsed = History.from_text(formatted, chat_template_name="my_model")

    # Verify the parsing worked
    assert parsed.role == history.role
    assert parsed.content == history.content
    print("✓ Inverse parsing verified")

LLM 包裝器 API

LLM 包裝器 API 提供了不同 LLM 後端的統一介面,確保了訓練和推理管道之間一致的輸入/輸出格式。主要的包裝器是用於 Hugging Face 模型的 TransformersWrapper 和用於 vLLM 推理的 vLLMWrapper

資料結構類

包裝器使用結構化的 TensorClass 物件來表示 LLM 資料的不同方面:

  • :class:`~torchrl.modules.llm.policies.Text`:包含具有 promptresponsefull 欄位的文字資料。

  • :class:`~torchrl.modules.llm.policies.ChatHistory`:包含具有 promptresponsefull 欄位的 History 物件。

  • :class:`~torchrl.modules.llm.policies.Tokens`:包含具有 promptresponsefull 欄位的 token 化資料。

  • :class:`~torchrl.modules.llm.policies.LogProbs`:包含具有 promptresponsefull 欄位的對數機率。

  • :class:`~torchrl.modules.llm.policies.Masks`:包含注意力掩碼和助手掩碼。

API 流程

包裝器在兩種不同的模式下執行:

生成模式 (`generate=True`):- 輸入:從 prompt 欄位讀取(例如,history.prompttext.prompttokens.prompt)- 輸出:寫入 responsefull 欄位。

  • response:僅包含新生成的內容。

  • full:包含完整的序列(prompt + response)。

對數機率模式 (`generate=False`):- 輸入:從 full 欄位讀取(例如,history.fulltext.fulltokens.full)- 輸出:將對數機率寫入相應的 full 欄位。

LLM-環境互動迴圈

LLM-Environment interaction loop

LLM-環境互動:LLM 生成響應,環境更新對話,轉換可以注入新訊息或工具。

在典型的 RL 或工具增強設定中,LLM 和環境在一個迴圈中互動:

  1. LLM 生成:LLM 包裝器接收 prompt(當前對話歷史),生成 response,並輸出一個 full 欄位。

包含 prompt 和 response 的連線。

  1. 環境步進:環境將 full 欄位作為下一個 prompt 提供給 LLM。這確保了對話

上下文隨著每次輪次而增長。更多詳細資訊請參閱 ref_env_llm_step

  1. 轉換:在下一個 LLM 步進之前,轉換可以修改對話,例如,透過插入新的使用者訊息、工具呼叫

或獎勵註釋。

  1. 重複:此過程重複進行所需的輪次數,從而實現多輪對話、工具使用和 RL 訓練。

這種設計允許在每一步對對話進行靈活的增強,支援高階 RL 和工具使用場景。

典型的虛擬碼迴圈

# Get the first prompt out of an initial query
obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
while not done:
    # LLM generates a response given the current prompt
    llm_output = llm(obs)
    # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
    obs = env.step(llm_output)

與 History 整合

當使用 input_mode="history" 時,包裝器可以與 History 類無縫整合。

  • 輸入:接收包含 prompt 欄位中 History 的 ChatHistory 物件。

  • 生成:應用聊天模板將 History 轉換為 token,生成響應,然後將完整的文字解析回 History 物件。

  • 輸出:返回一個 ChatHistory,其中包含:- prompt:原始對話歷史- response:僅包含助手響應的新 History 物件- full:包含新響應已追加的完整對話歷史。

此設計允許自然的對話流程,其中每個生成步驟都會擴充套件對話歷史,使其成為多輪對話系統的理想選擇。

Prompt 與 Response 及填充

LLM output data format (Tokens, Masks, Padded vs. Sparse)

LLM 輸出結構:Token、LogProbs 和 Mask 的填充與稀疏表示。

上圖說明了 TorchRL LLM API 中使用的主要輸出類的結構:

  • Tokens(以及擴充套件的 LogProbs):- 填充格式:批次中的所有序列都填充到相同的長度(使用特殊的 pad token),使其適合張量操作。prompt 和 response 被連線起來形成 tokens.full,掩碼指示有效位置與填充位置。- 稀疏格式:每個序列保留其原始長度(無填充),表示為張量列表。這對於可變長度資料更節省記憶體。

  • Masks:顯示了兩種主要掩碼:- mask.attention_mask_all 標記有效(非 pad)token。- mask.assistant_mask_all 標記由助手生成的 token(對於 RLHF 和 SFT 訓練有用)。

  • Text:未詳細顯示,因為它只是 prompt、response 或完整序列的解碼字串表示。

此格式確保所有 LLM 輸出(Tokens、LogProbs、Masks、Text)都是一致的且易於操作,無論您使用填充批處理還是稀疏批處理。

總的來說,我們建議使用未填充的資料,因為它更節省記憶體且易於操作。例如,當從緩衝區收集多個填充元素時,很難清楚地理解如何重新填充它們以將它們組合成一個連貫的批次。使用未填充的資料更直接。

模組

LLM 包裝器 API 提供了不同 LLM 後端的統一介面,確保了訓練和推理管道之間一致的輸入/輸出格式。

包裝器

這些原語的主要目標是:

  • 統一訓練和推理管道的輸入/輸出資料格式。

  • 統一後端之間的輸入/輸出資料格式(以便在不同的損失和收集器中使用不同的後端)。

  • 提供適當的工具來在典型的 RL 環境中構建這些物件(資源分配、非同步執行、權重更新等)。

LLMWrapperBase(*args, **kwargs)

LLM 包裝器基類。

TransformersWrapper(*args, **kwargs)

Hugging Face Transformers 模型的包裝器類,為文字生成和對數機率計算提供了一致的介面。

vLLMWrapper(*args, **kwargs)

vLLM 模型的包裝器類,為文字生成和對數機率計算提供了一致的介面。

RemoteTransformersWrapper(model[, ...])

一個用於 TransformersWrapper 的遠端 Ray actor 包裝器,提供了一個簡化的介面。

AsyncVLLM(engine_args[, num_replicas, ...])

一個管理多個非同步 vLLM 引擎 actor 以進行分散式推理的服務。

ChatHistory([prompt, response, full, ...])

Text([prompt, response, full, device, names])

LogProbs([prompt, response, full, padded, ...])

Masks([all_attention_mask, ...])

Tokens([prompt, response, full, padded, ...])

遠端包裝器

TorchRL 提供了遠端包裝器類,可以使用 Ray 實現 LLM 包裝器的分散式執行。這些包裝器提供了一個簡化的介面,不需要顯式的 remote()get() 呼叫,使其易於在分散式環境中使用。

注意

對於 vLLM:請改用 AsyncVLLM

對於基於 vLLM 的推理,我們建議直接使用 AsyncVLLM 而不是遠端包裝器。AsyncVLLM 提供了更好的效能、資源利用率和內建負載均衡。有關詳細資訊,請參閱上面的 非同步 vLLM 引擎(推薦) 部分。

遠端包裝器主要用於基於 Transformers 的模型或其他 AsyncVLLM 不適用的用例。

主要特性

  • 簡化的介面:無需顯式呼叫 remote()get()

  • 完整的 API 相容性:公開了 LLMWrapperBase 基類的所有公共方法。

  • 自動 Ray 管理:內部處理 Ray 初始化和遠端執行。

  • 屬性訪問:所有屬性都可以透過遠端包裝器訪問。

  • 錯誤處理:從遠端 actor 正確傳播錯誤。

  • 資源管理:支援上下文管理器以進行自動清理。

模型引數要求

  • RemoteTransformersWrapper:僅接受字串模型名稱/路徑。Transformers 模型不可序列化。

支援的後端

目前,只有基於 Transformers 的模型透過遠端包裝器支援。對於 vLLM 模型,請改用 AsyncVLLM

用法示例

import ray
from torchrl.modules.llm.policies import RemoteTransformersWrapper
from torchrl.data.llm import History
from torchrl.modules.llm.policies import ChatHistory, Text
from tensordict import TensorDict

# Initialize Ray (if not already done)
if not ray.is_initialized():
    ray.init()

# Transformers wrapper (only string models supported)
# The remote wrappers implement context managers for proper resource cleanup:
with RemoteTransformersWrapper(
    model="gpt2",
    max_concurrency=16,
    input_mode="text",
    generate=True,
    generate_kwargs={"max_new_tokens": 30}
) as remote_transformers:

    text_input = TensorDict({"text": Text(prompt="Hello world")}, batch_size=(1,))
    result = remote_transformers(text_input)
    print(result["text"].response)

效能考慮

  • 網路開銷:遠端執行增加了網路通訊開銷。

  • 序列化:資料在傳送到遠端 actor 時會被序列化。

  • 記憶體:每個遠端 actor 都維護自己的模型副本。

  • 併發:多個遠端包裝器可以併發執行。

  • 最大併發數:使用 max_concurrency 引數控制對每個遠端 actor 的併發呼叫次數。

  • 清理:始終使用上下文管理器或呼叫 cleanup_batching() 以防止因批處理鎖而掛起。

Utils

make_async_vllm_engine(model_name[, ...])

建立非同步 vLLM 引擎服務。

stateless_init_process_group_async(...)

為分散式通訊初始化一個無狀態程序組(非同步版本)。

make_vllm_worker(*, model_name[, devices, ...])

建立具有張量並行支援的 vLLM 推理引擎。

stateless_init_process_group(master_address, ...)

為分散式通訊初始化一個無狀態程序組。

收集器

TorchRL 提供專門的收集器類(LLMCollectorRayLLMCollector),這些類針對 LLM 用例進行了定製。我們還為一些推理引擎提供了專用的更新器。

有關收集器 API 的更多詳細資訊,請參閱 ref_collectors。簡而言之,收集器的想法是將管道的推理部分隔離到一個專用類中。收集器通常以策略和環境為輸入,並在兩者之間交替執行。在“經典”設定中,策略類似於正在訓練的策略(具有一些可選的額外探索)。在 LLM 微調的上下文中,策略通常是一個專業的推理引擎,例如 vLLM 伺服器。收集器由以下引數和功能定義:

  • 同步/非同步:收集器是否應以同步或非同步模式執行。在同步模式下,收集器將在最佳化/訓練步驟之間交替執行推理步驟。在非同步模式下,收集器將與最佳化/訓練步驟並行執行推理步驟。可以向收集器傳遞一個回放緩衝區,這樣收集器就可以直接寫入它。在其他情況下,收集器可以被迭代以收集資料。

  • 步數:收集器構建時具有一定的步數預算,以及在收集期間每次 yield 批次中要包含的步數。

  • 權重更新器:權重更新器是更新策略權重的類。將權重更新隔離到一個專用類中,可以根據策略規範輕鬆實現不同的權重更新策略。

策略版本跟蹤

LLM 收集器還允許跟蹤策略的版本,這對於某些用例很有用。這透過將 PolicyVersion 轉換新增到環境中來實現,然後由收集器在每次權重更新後遞增。為此,可以向收集器建構函式提供轉換的狀態版本或一個布林值。

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

vLLMUpdater(*args[, v2])

一個將權重發送到 vLLM worker 的類。

vLLMUpdaterV2(vllm_engine)

使用 RLvLLMEngine 介面的簡化 vLLM 權重更新器。

LLMCollector(env, *[, policy, ...])

SyncDataCollector 的簡化版本,用於 LLM 推理。

RayLLMCollector(env, *[, policy, ...])

LLM Collector 的輕量級 Ray 實現,可以遠端擴充套件和取樣。

環境

環境層負責資料載入、工具執行、獎勵計算和格式化的編排。當使用 TorchRL 微調 LLM 時,環境是推理管道的關鍵組成部分,與策略和收集器並列。

ChatEnv

ChatEnv 是 LLM 環境的空白畫布——它是一個基本工具,旨在透過新增特定功能的轉換來擴充套件。基礎 ChatEnv 提供了使用 History 格式管理對話狀態的基本結構,但它故意保持最小化以實現最大的靈活性。

核心功能

ChatEnv 在三種主要模式下執行:- History 模式:使用 History 物件進行對話管理。- Text 模式:使用簡單的文字字串進行輸入/輸出。- Tokens 模式:使用 token 化資料進行輸入/輸出。

環境透過以下方式維護對話狀態:- 重置:初始化帶有可選系統 prompt 的新對話。- 步進:接收 LLM 的響應並更新對話歷史,準備下一個 prompt。

基於轉換的架構

轉換是擴充套件 ChatEnv 以實現特定功能的​​主要方式:

與 LLM 包裝器整合

ChatEnv 設計為與 TransformersWrappervLLMWrapper 無縫協同工作。環境負責管理對話狀態,而包裝器負責實際的 LLM 推理,實現了清晰的關注點分離。

在每次呼叫 step 時,環境會:

  • 接收 LLM 的輸出,特別是 full 欄位,其中包含到目前為止的整個對話,包括新響應(例如,history.fulltext.fulltokens.full)。

  • 將此 full 欄位設定為下一個 LLM 步進的 prompt(例如,td[“next”, “history”].prompttd[“next”, “text”].prompttd[“next”, “tokens”].prompt)。

  • 可以選擇應用轉換以在下一個 LLM 步進之前插入新的使用者訊息、工具呼叫或其他對話修改,以最佳化 prompt。

這種機制支援無縫的多輪互動,並支援複雜的用例,如工具使用和獎勵塑形。

特定任務的環境

我們提供了一些特定任務的環境,例如用於 GSM8K 資料集的 GSM8KEnv,用於 IFEval 資料集的 IFEvalEnv,以及用於 MLGym 整合的 MLGymEnv

這些環境包裝了一個 ChatEnv,並在一個 TransformedEnv 類中添加了一個 DataLoadingPrimer 轉換(以及一個可選的獎勵解析轉換)。

ChatEnv(*args, **kwargs)

一個用於 LLM 的聊天環境,設計為一個用於對話和 RL 的空白畫布。

DatasetChatEnv(*args, **kwargs)

用於從資料集中提取查詢的聊天環境的基類。

GSM8KEnv(*args, **kwargs)

GSM8K 資料集環境。

make_gsm8k_env([dataset, num_envs, repeats, ...])

一個基於 LLMEnv 的 GSM8K 環境的構建器。

GSM8KPrepareQuestion([in_keys, out_keys])

在使用 GSM8k 作為 LLMEnv 的一部分時準備 prompt 的轉換。

IFEvalEnv(*args, **kwargs)

基於 IFEval 資料集的聊天環境。

IfEvalScorer(*[, instruction_ids_key, ...])

IF-Eval 任務的評分器。

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

用於語言模型的文字生成環境。

LLMHashingEnv(*args, **kwargs)

一個使用雜湊模組來識別唯一觀測值的文字生成環境。

make_mlgym(*[, task, tasks, tokenizer, ...])

將 MLGymEnv 包裝成 TorchRL 環境。

MLGymWrapper(*args, **kwargs)

MLGym 環境的薄包裝器。

GSM8KRewardParser(tokenizer[, in_keys, ...])

用於 GSM8KEnv 或 make_gsm8k_env 的獎勵解析器。

變換 (Transforms)

轉換用於在將資料傳遞給 LLM 之前修改資料。工具通常作為轉換實現,並附加到一個基礎環境,如 ChatEnv

一個工具轉換的例子是 PythonInterpreter 轉換,它用於在 LLM 響應的上下文中執行 Python 程式碼。

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(reset_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

同樣,從資料集中載入資料的環境只是 ChatEnv 的特殊例項,並增加了 DataLoadingPrimer 轉換(以及一些專門的獎勵解析轉換)。

設計獎勵轉換

在為 LLM 環境設計獎勵轉換時,必須考慮幾個關鍵因素,以確保與訓練管道的正確整合。 GSM8KRewardParserIfEvalScorer 的示例為獎勵轉換設計提供了絕佳的模板。

獎勵形狀要求

獎勵張量必須具有與 logits 相同的維度數,這通常比環境批次大小多兩個維度。

  • 稀疏獎勵:形狀 (*bsz, 1, 1) - 每個序列一個獎勵。

  • 密集獎勵:形狀 (*bsz, num_tokens, 1) - 每個 token 一個獎勵。

此形狀要求確保與損失計算管道相容。例如,在 GSM8K 獎勵解析器中:

# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))

Done 狀態管理

妥善管理 done 狀態對於防止無限生成至關重要。常見策略包括:

  1. 完成為基礎的終止:當響應完成時設定 done(例如,History.complete=True)。

  2. 基於內容的終止:檢測到特定內容時設定 done(例如,<answer> 塊)。

  3. 基於步數的終止:使用 StepCounter 來預設步數限制。

IFEvalScorer 的示例

if self.set_done_if_answer and bool(answer_blocks):
    next_tensordict.set("done", torch.ones(...))
    next_tensordict.set("terminated", torch.ones(...))

輸入模式處理

獎勵轉換必須正確處理不同的輸入模式:

  • History 模式:從 ("history", "full")("history", "response") 中提取文字。

  • Text 模式:直接使用 ("text", "full")("text", "response") 中的文字。

  • Tokens 模式:從 ("tokens", "full")("tokens", "response") 解碼 token。

GSM8K 獎勵解析器演示了此模式。

if input_mode == "history":
    responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
    if hasattr(responses, "content"):
        text_completion = responses.content
elif input_mode == "text":
    text_completion = responses
elif input_mode == "tokens":
    text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())

規範管理

準確指定獎勵和觀察規範對於正確初始化環境至關重要。GSM8K 和 IFEval 都提供了很好的示例。

def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    shape = reward_spec.shape + (1, 1)
    reward_spec.update(
        Composite(
            reward_answer=Unbounded(shape),
            reward_think=Unbounded(shape),
            reward_right=Unbounded(shape),
            reward_contained=Unbounded(shape),
            reward=Unbounded(shape),
            success=Unbounded(shape, dtype=torch.bool),
        )
    )
    return reward_spec

批處理注意事項

為了高效處理,請妥善處理批處理資料:

  1. 展平批次維度:使用 tensordict.view(-1) 進行處理。

  2. 重塑結果:處理後恢復原始批次結構。

  3. 處理可變長度序列:使用適當的填充和遮蔽。

獎勵聚合策略

考慮不同的獎勵聚合方法:

  1. 簡單聚合:對多個獎勵元件求和或取平均。

  2. 加權聚合:對不同元件應用不同權重。

  3. 條件獎勵:基於特定條件或閾值設定獎勵。

IFEvalScorer 演示了複雜的聚合策略。

def default_reward_aggregator(self, score: IFEvalScoreData, ...):
    # Format score (max 1.0)
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure score (max 1.0)
    structure_score = think_score + answer_score

    # Completion bonus (max 0.2)
    completion_bonus = float(complete) * 0.2

    return format_score + structure_score + completion_bonus

回放緩衝區中的後處理

獎勵也可以透過將轉換附加到回放緩衝區來事後計算。但是,done 狀態捕獲必須保留在環境轉換中,因為它需要在資料收集期間即時發生。

錯誤處理和魯棒性

實現魯棒的錯誤處理以應對解析失敗。

try:
    cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
    cot, potential_answer = ("", "")

效能考慮

  1. 避免冗餘計算:在可能的情況下快取解析結果。

  2. 使用高效的文字處理:根據需要利用正則表示式或 XML 解析。

  3. 最小化記憶體分配:重用張量並避免不必要的複製。

透過遵循這些設計原則,可以將獎勵轉換有效地整合到 LLM 訓練管道中,同時保持效能和可靠性。

AddThinkingPrompt(cond[, prompt, ...])

一個新增思考 prompt 以鼓勵 LLM 重新考慮其響應的轉換。

BrowserTransform([allowed_domains, ...])

一個啟用網頁瀏覽功能的轉換。

DataLoadingPrimer(*args[, use_ray_service])

一個從資料載入器載入資料並使用 stack_method 將其轉換為 tensordict 的 primer。

KLComputation([gen_log_probs_full_key, ...])

一個用於計算兩個對數機率張量之間的 KL 散度,並可選地將其新增到獎勵中的轉換。

KLRewardTransform(*args[, use_ray_service])

用於計算基於 KL 散度的獎勵的舊轉換。

MCPToolTransform(tools, tool_schemas[, ...])

一個在 LLM 操作響應中執行 MCP 風格工具的轉換。

PolicyVersion(version_type, ] =)

一個跟蹤策略版本的轉換。

PythonInterpreter([tokenizer, tool_name, ...])

一個在 LLM 響應中執行 Python 程式碼的轉換。

RayDataLoadingPrimer(*[, dataloader, ...])

一個 DataLoadingPrimer,它建立了一個可以被多個環境共享的單個 actor。

RetrieveKL(*args[, use_ray_service])

一個用於檢索兩個模型對數機率之間 KL 散度的轉換。

RetrieveLogProb(model, *[, ...])

一個用於從模型檢索對數機率以進行 KL 散度計算的轉換。

TemplateTransform(tokenizer[, chat_template])

一個在正向傳播期間對映應用聊天模板到輸入字串,並在反向傳播期間將字串解析回模板的轉換。

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

對指定輸入應用分詞操作。

as_nested_tensor(list_of_tensordicts)

將 tensordict 列表堆疊成具有巢狀張量的單個 tensordict。

as_padded_tensor(list_of_tensordicts[, dim, ...])

將 tensordict 列表堆疊成具有填充張量的單個 tensordict。

目標

LLM 的訓練後需要專門的損失函式,這些函式經過調整以適應語言模型的獨特特性。

GRPO

GRPOLoss 類是 PPOLoss 類的薄包裝器,它封裝了 LLM 特有的功能。

GRPOLoss(*args, **kwargs)

GRPO 損失。

GRPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

蒙特卡羅優勢計算引擎。

SFT

SFTLoss(*args, **kwargs)

監督微調損失。

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

一個回放緩衝區轉換,用於為每個 prompt 選擇 top-k 獎勵。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源