PromptTensorDictTokenizer¶
- class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)[原始碼]¶
用於 prompt 資料集的 tokenization 示例。
返回一個 tokenizer 函式,該函式讀取包含 prompt 和 label 的示例並對其進行 tokenization。
- 引數:
tokenizer (來自 transformers 庫的 tokenizer) – 要使用的 tokenizer。
max_length (int) – 序列的最大長度。
key (str, optional) – 要查詢文字的鍵。預設為
"prompt"。padding (str, optional) – padding 的型別。預設為
"max_length"。truncation (bool, optional) – 序列是否應截斷到 max_length。
return_tensordict (bool, optional) – 如果為
True,則返回 TensoDict。否則,將返回原始資料。device (torch.device, optional) – 用於儲存資料的裝置。如果
return_tensordict=False,則忽略此選項。
此類中的
__call__()方法將執行以下操作:讀取與
label字串連線的prompt字串並對其進行 tokenization。結果將儲存在"input_ids"TensorDict 條目中。使用 prompt 的最後一個有效 token 的索引寫入
"prompt_rindex"條目。寫入
"valid_sample",該條目標識 tensordict 中哪些條目具有足夠的 token 以滿足max_length標準。返回一個帶有 tokenized 輸入的
tensordict.TensorDict例項。
tensordict 的 batch-size 將與輸入的 batch-size 匹配。
示例
>>> from transformers import AutoTokenizer >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") >>> tokenizer.pad_token = tokenizer.eos_token >>> example = { ... "prompt": ["This prompt is long enough to be tokenized.", "this one too!"], ... "label": ["Indeed it is.", 'It might as well be.'], ... } >>> fn = PromptTensorDictTokenizer(tokenizer, 50) >>> print(fn(example)) TensorDict( fields={ attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False), prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False), valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False)