快捷方式

TokenizedDatasetLoader

class torchrl.data.TokenizedDatasetLoader(split, max_length, dataset_name, tokenizer_fn: type[TensorDictTokenizer], pre_tokenization_hook=None, root_dir=None, from_disk=False, valid_size: int = 2000, num_workers: int | None = None, tokenizer_class=None, tokenizer_model_name=None)[原始碼]

載入一個已分詞的資料集,並快取其記憶體對映副本。

引數:
  • split (str) – "train""valid" 之一。

  • max_length (int) – 最大序列長度。

  • dataset_name (str) – 資料集的名稱。

  • tokenizer_fn (callable) – 分詞方法建構函式,例如 torchrl.data.llm.TensorDictTokenizer。呼叫時,它應返回一個 tensordict.TensorDict 例項或一個包含分詞資料的類似字典的結構。

  • pre_tokenization_hook (callable, optional) – 在分詞之前在 Dataset 上呼叫。它應返回一個修改後的 Dataset 物件。預期用途是執行需要修改整個 Dataset 的任務,而不是修改單個數據點,例如根據特定條件丟棄某些資料點。資料上的分詞和其他“逐元素”操作由對映到 Dataset 的 process 函式執行。

  • root_dir (path, optional) – 儲存資料集的路徑。預設為 "$HOME/.cache/torchrl/data"

  • from_disk (bool, optional) – 如果為 True,將使用 datasets.load_from_disk()。否則,將使用 datasets.load_dataset()。預設為 False

  • valid_size (int, optional) – 驗證資料集的大小(如果 split 以 "valid" 開頭)將被截斷到此值。預設為 2000 個專案。

  • num_workers (int, optional) – 在分詞過程中呼叫的 datasets.dataset.map() 的工作執行緒數。預設為 max(os.cpu_count() // 2, 1)

  • tokenizer_class (Type, optional) – 分詞器類,例如 AutoTokenizer (預設)。

  • tokenizer_model_name (str, optional) – 應從中收集詞彙表的模型。預設為 "gpt2"

資料集將儲存在 <root_dir>/<split>/<max_length>/ 中。

示例

>>> from torchrl.data.llm import TensorDictTokenizer
>>> from torchrl.data.llm.reward import  pre_tokenization_hook
>>> split = "train"
>>> max_length = 550
>>> dataset_name = "CarperAI/openai_summarize_comparisons"
>>> loader = TokenizedDatasetLoader(
...     split,
...     max_length,
...     dataset_name,
...     TensorDictTokenizer,
...     pre_tokenization_hook=pre_tokenization_hook,
... )
>>> dataset = loader.load()
>>> print(dataset)
TensorDict(
    fields={
        attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([185068]),
    device=None,
    is_shared=False)
static dataset_to_tensordict(dataset: datasets.Dataset | TensorDict, data_dir: Path, prefix: NestedKey = None, features: Sequence[str] = None, batch_dims=1, valid_mask_key=None)[原始碼]

將資料集轉換為記憶體對映的 TensorDict。

如果資料集已經是 TensorDict 例項,它將被簡單地轉換為記憶體對映的 TensorDict。否則,資料集應具有 features 屬性,該屬性是一個字串序列,指示可以在資料集中找到的特徵。如果不存在,則必須將 features 顯式傳遞給此函式。

引數:
  • dataset (datasets.Dataset, TensorDict等價物) – 要轉換為記憶體對映 TensorDict 的資料集。如果 featuresNone,則它必須具有一個 features 屬性,其中包含要寫入 tensordict 的鍵列表。

  • data_dir (Path等價物) – 應將資料寫入的目錄。

  • prefix (NestedKey, optional) – 資料集位置的`prefix`。這可用於區分同一資料集的不同副本,這些副本已進行了不同的預處理。

  • features (str 的序列, optional) – 一個字串序列,指示可以在資料集中找到的特徵。

  • batch_dims (int, optional) – 資料的 `batch_dimensions` 數量(即 tensordict 可以索引的維度數)。預設為 1。

  • valid_mask_key (NestedKey, optional) – 如果提供,將嘗試收集此條目並用於過濾資料。預設為 None(即,無過濾鍵)。

返回: 一個包含資料集的記憶體對映張量的 TensorDict。

示例

>>> from datasets import Dataset
>>> import tempfile
>>> data = Dataset.from_dict({"tokens": torch.randint(20, (10, 11)), "labels": torch.zeros(10, 11)})
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     data_memmap = TokenizedDatasetLoader.dataset_to_tensordict(
...         data, data_dir=tmpdir, prefix=("some", "prefix"), features=["tokens", "labels"]
...     )
...     print(data_memmap)
TensorDict(
    fields={
        some: TensorDict(
            fields={
                prefix: TensorDict(
                    fields={
                        labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                        tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
                    batch_size=torch.Size([10]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
load()[原始碼]

載入預處理的記憶體對映資料集(如果存在),否則建立它。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源