torch.utils.data#

Created on: Jun 13, 2025 | Last updated: Jun 13, 2025

At the heart of PyTorch data loading utility is the torch.utils.data.DataLoader class. It represents a Python iterable over a dataset, with support for:

  • map-style and iterable-style datasets,

  • customizing data loading order,

  • automatic batching,

  • single- and multi-process data loading,

  • automatic memory pinning.

These options are configured by the constructor arguments of a DataLoader, which has the following signature:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

The sections below describe in detail the effects and usages of these options.

Dataset Types#

The most important argument of the DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

  • map-style datasets,

  • iterable-style datasets.

Map-style datasets#

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.
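
For illustration, a minimal self-contained sketch of a map-style dataset (SquaresDataset is a made-up example, not a PyTorch class):

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    # Hypothetical map-style dataset: maps an integer index to a (value, value**2) pair.
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor(float(idx))
        return x, x ** 2

dataset = SquaresDataset(8)
print(dataset[3])    # accessed as dataset[idx] -> (tensor(3.), tensor(9.))
print(len(dataset))  # 8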

See Dataset for more details.

Iterable-style datasets#

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
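
A rough sketch of such a dataset (the range-based "stream" below is a stand-in for a database cursor or socket):

import torch
from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    # Hypothetical iterable-style dataset yielding samples from a (simulated) stream.
    def __init__(self, num_records):
        self.num_records = num_records

    def __iter__(self):
        for i in range(self.num_records):  # e.g. rows read from a remote source
            yield torch.tensor([float(i)])

loader = DataLoader(StreamDataset(4), batch_size=2)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 1])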

See IterableDataset for more details.

Note

When using an IterableDataset with multi-process data loading, the same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See the IterableDataset documentation for how to achieve this.

Data Loading Order and Sampler#

For iterable-style datasets, data loading order is entirely controlled by the user-defined iterable. This allows easier implementations of chunk-reading and dynamic batch size (e.g., by yielding a batched sample at each time).

The rest of this section concerns the case with map-style datasets. torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

A sequential or shuffled sampler will be automatically constructed based on the shuffle argument to a DataLoader. Alternatively, users may use the sampler argument to specify a custom Sampler object that at each time yields the next index/key to fetch.

A custom Sampler that yields a list of batch indices at a time can be passed as the batch_sampler argument. Automatic batching can also be enabled via the batch_size and drop_last arguments. See the next section for more details on this.
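
For example, a minimal sketch of passing a custom Sampler that walks the dataset in reverse order (ReverseSampler is illustrative, not a built-in):

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class ReverseSampler(Sampler):
    # Illustrative sampler: yields dataset indices from last to first.
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

dataset = TensorDataset(torch.arange(6))
loader = DataLoader(dataset, sampler=ReverseSampler(dataset), batch_size=2)
for (batch,) in loader:
    print(batch)  # tensor([5, 4]), then tensor([3, 2]), then tensor([1, 0])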

Note

Neither sampler nor batch_sampler is compatible with iterable-style datasets, since such datasets have no notion of a key or an index.

Loading Batched and Non-Batched Data#

DataLoader supports automatically collating individual fetched data samples into batches via the arguments batch_size, drop_last, batch_sampler, and collate_fn (which has a default function).

Automatic batching (default)#

This is the most common case, and corresponds to fetching a minibatch of data and collating it into batched samples, i.e., Tensors with one dimension being the batch dimension (usually the first).

batch_size(預設為 1)不為 None 時,資料載入器會產生批次樣本而不是單個樣本。batch_sizedrop_last 引數用於指定資料載入器如何獲取資料集鍵的批次。對於對映式資料集,使用者還可以選擇指定 batch_sampler,它一次產生一個鍵列表。

Note

The batch_size and drop_last arguments are essentially used to construct a batch_sampler from the sampler. For map-style datasets, the sampler is either provided by the user or constructed based on the shuffle argument. For iterable-style datasets, the sampler is a dummy infinite one. See this section for more details on samplers.

Note

When fetching from iterable-style datasets with multi-processing, the drop_last argument drops the last non-full batch of each worker's dataset replica.

After fetching a list of samples using the indices from sampler, the function passed as the collate_fn argument is used to collate lists of samples into batches.

In this case, loading from a map-style dataset is roughly equivalent with:

for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

and loading from an iterable-style dataset is roughly equivalent with:

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

A custom collate_fn can be used to customize collation, e.g., padding sequential data to the max length of a batch. See this section for more about collate_fn.
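
As an illustration, a sketch of a collate_fn that pads variable-length 1-D sequences to the longest one in the batch (pad_collate and the toy data are our own; the padding itself uses torch.nn.utils.rnn.pad_sequence):

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Toy map-style dataset: a plain list of variable-length 1-D tensors.
sequences = [torch.arange(n) for n in (2, 5, 3, 4)]

def pad_collate(batch):
    # Pad every sequence in the batch to the length of the longest one.
    lengths = torch.tensor([len(seq) for seq in batch])
    padded = pad_sequence(batch, batch_first=True, padding_value=0)
    return padded, lengths

loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
for padded, lengths in loader:
    print(padded.shape, lengths)  # e.g. torch.Size([2, 5]) tensor([2, 5])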

Disable automatic batching#

In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples. For example, it could be cheaper to directly load batched data (e.g., bulk reads from a database or reading continuous chunks of memory), or the batch size is data dependent, or the program is designed to work on individual samples. Under these scenarios, it's likely better to not use automatic batching (where collate_fn is used to collate the samples), but let the data loader directly return each member of the dataset object.

When both batch_size and batch_sampler are None (default value for batch_sampler is already None), automatic batching is disabled. Each sample obtained from the dataset is processed with the function passed as the collate_fn argument.

When automatic batching is disabled, the default collate_fn simply converts NumPy arrays into PyTorch Tensors, and keeps everything else untouched.

In this case, loading from a map-style dataset is roughly equivalent with:

for index in sampler:
    yield collate_fn(dataset[index])

and loading from an iterable-style dataset is roughly equivalent with:

for data in iter(dataset):
    yield collate_fn(data)

See this section for more about collate_fn.
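
For instance, a sketch of disabling automatic batching for a dataset that already yields ready-made batches (BatchedStream is a hypothetical example, standing in for bulk database reads):

import torch
from torch.utils.data import DataLoader, IterableDataset

class BatchedStream(IterableDataset):
    # Hypothetical dataset that already yields fully-formed batches.
    def __iter__(self):
        for start in range(0, 12, 4):
            yield torch.arange(start, start + 4)

# batch_size=None disables automatic batching: each yielded item only passes
# through the default collate_fn (default_convert) and is returned as-is.
loader = DataLoader(BatchedStream(), batch_size=None)
for batch in loader:
    print(batch)  # tensors of shape (4,)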

Working with collate_fn#

The usage of collate_fn is slightly different when automatic batching is enabled or disabled.

When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator. In this case, the default collate_fn simply converts NumPy arrays into PyTorch tensors.

When automatic batching is enabled, collate_fn is called with a list of data samples at each time. It is expected to collate the input samples into a batch for yielding from the data loader iterator. The rest of this section describes the behavior of the default collate_fn (default_collate()).

For instance, if each data sample consists of a 3-channel image and an integral class label, i.e., each element of the dataset returns a tuple (image, class_index), the default collate_fn collates a list of such tuples into a single tuple of a batched image Tensor and a batched class label Tensor. In particular, the default collate_fn has the following properties:

  • It always prepends a new dimension as the batch dimension.

  • It automatically converts NumPy arrays and Python numerical values into PyTorch Tensors.

  • It preserves the data structure, e.g., if each sample is a dictionary, it outputs a dictionary with the same set of keys but batched Tensors as values (or lists if the values can not be converted into Tensors). Same for lists, tuples, namedtuples, etc.

Users may use customized collate_fn to achieve custom batching, e.g., collating along a dimension other than the first, padding sequences of various lengths, or adding support for custom data types.

If you run into a situation where the outputs of DataLoader have dimensions or type that is different from your expectation, you may want to check your collate_fn.

Single- and Multi-process Data Loading#

A DataLoader uses single-process data loading by default.

Within a Python process, the Global Interpreter Lock (GIL) prevents truly fully parallelizing Python code across threads. To avoid blocking computation code with data loading, PyTorch provides an easy switch to perform multi-process data loading by simply setting the argument num_workers to a positive integer.

Single-process data loading (default)#

In this mode, data fetching is done in the same process a DataLoader is initialized. Therefore, data loading may block computing. However, this mode may be preferred when resource(s) used for sharing data among processes (e.g., shared memory, file descriptors) is limited, or when the entire dataset is small and can be loaded entirely in memory. Additionally, single-process loading often shows more readable error traces and thus is useful for debugging.

Multi-process data loading#

Setting the argument num_workers as a positive integer will turn on multi-process data loading with the specified number of loader worker processes.
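
A minimal sketch (the random TensorDataset is just a stand-in for a real dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

# Four worker processes fetch and collate batches in the background.
loader = DataLoader(dataset, batch_size=8, num_workers=4)

for features, labels in loader:
    pass  # training step would go here

On platforms that use the spawn start method, wrap this in an if __name__ == '__main__': guard as described under Platform-specific behaviors below.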

Warning

After several iterations, the loader worker processes will consume the same amount of CPU memory as the parent process for all Python objects in the parent process which are accessed from the worker processes. This can be problematic if the Dataset contains a lot of data (e.g., you are loading a very large list of filenames at Dataset construction time) and/or you are using a lot of workers (overall memory usage is number of workers * size of parent process). The simplest workaround is to replace Python objects with non-refcounted representations such as Pandas, Numpy or PyArrow objects. Check out issue #13246 for more details on why this occurs and example code for how to work around these problems.

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize and fetch data. This means that dataset access together with its internal IO and transforms (including collate_fn) runs in the worker process.

torch.utils.data.get_worker_info() returns various useful information in a worker process (including the worker id, dataset replica, initial seed, etc.) and returns None in the main process. Users may use this function in dataset code and/or worker_init_fn to individually configure each dataset replica, and to determine whether the code is running in a worker process. For example, this can be particularly helpful in sharding the dataset.

For map-style datasets, the main process generates the indices using sampler and sends them to the workers. So any shuffle randomization is done in the main process, which guides loading by assigning indices to load.

For iterable-style datasets, since each worker process gets a replica of the dataset object, naive multi-process loading will often result in duplicated data. Using torch.utils.data.get_worker_info() and/or worker_init_fn, users may configure each replica independently. (See the IterableDataset documentation for how to achieve this.) For similar reasons, in multi-process loading, the drop_last argument drops the last non-full batch of each worker's iterable-style dataset replica.

Workers are shut down once the end of the iteration is reached, or when the iterator becomes garbage collected.

Warning

It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.

Platform-specific behaviors#

Since workers rely on Python multiprocessing, worker launch behavior is different on Windows compared to Unix.

  • On Unix, fork() is the default multiprocessing start method. Using fork(), child workers typically can access the dataset and Python argument functions directly through the cloned address space.

  • On Windows or MacOS, spawn() is the default multiprocessing start method. Using spawn(), another interpreter is launched which runs your main script, followed by the internal worker function that receives the dataset, collate_fn and other arguments through pickle serialization.

This separate serialization means that you should take two steps to ensure you are compatible with Windows while using multi-process data loading (a sketch of this layout follows the list):

  • Wrap most of your main script's code within an if __name__ == '__main__': block, to make sure it doesn't run again (most likely generating an error) when each worker process is launched. You can place your dataset and DataLoader instance creation logic here, as it doesn't need to be re-executed in workers.

  • Make sure that any custom collate_fn, worker_init_fn or dataset code is declared as top level definitions, outside of the __main__ check. This ensures that they are available in worker processes. (This is needed since functions are pickled as references only, not bytecode.)
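
A sketch of that layout (MyDataset and my_collate are illustrative names):

import torch
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    # Top-level definition, so spawned workers can unpickle it by reference.
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return torch.tensor(float(idx))

def my_collate(batch):
    # Top-level collate_fn, also picklable by reference.
    return torch.stack(batch)

if __name__ == "__main__":
    # Dataset/DataLoader creation and the loop stay inside the guard, so they
    # are not re-executed when each worker process is spawned.
    loader = DataLoader(MyDataset(), batch_size=10, num_workers=2, collate_fn=my_collate)
    for batch in loader:
        print(batch.shape)  # torch.Size([10])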

Randomness in multi-process data loading#

By default, each worker will have its PyTorch seed set to base_seed + worker_id, where base_seed is a long generated by the main process using its RNG (thereby, consuming a RNG state mandatorily) or a specified generator. However, seeds for other libraries may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See this section in FAQ.)

In worker_init_fn, you may access the PyTorch seed set for each worker with either torch.utils.data.get_worker_info().seed or torch.initial_seed(), and use it to seed other libraries before data loading.
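
For example, a sketch of a worker_init_fn that seeds NumPy's global RNG from the per-worker PyTorch seed (the function name is ours):

import numpy as np
import torch

def seed_numpy_worker(worker_id):
    # torch.initial_seed() already differs per worker (base_seed + worker_id);
    # fold it into 32 bits, the range accepted by NumPy's legacy seeding.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

# loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_numpy_worker)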

Memory Pinning#

Host to GPU copies are much faster when they originate from pinned (page-locked) memory. See Use pinned memory buffers for more details on when and how to use pinned memory generally.

For data loading, passing pin_memory=True to a DataLoader will automatically put the fetched data Tensors in pinned memory, and thus enables faster data transfer to CUDA-enabled GPUs.

The default memory pinning logic only recognizes Tensors and maps and iterables containing Tensors. By default, if the pinning logic sees a batch that is a custom type (which will occur if you have a collate_fn that returns a custom batch type), or if each element of your batch is a custom type, the pinning logic will not recognize them, and it will return that batch (or those elements) without pinning the memory. To enable memory pinning for custom batch or data type(s), define a pin_memory() method on your custom type(s).

See the example below.

Example

import torch
from torch.utils.data import DataLoader, TensorDataset

class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='', in_order=True)[source]#

Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.

The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizing loading order and optional automatic batching (collation) and memory pinning.

See the torch.utils.data documentation page for more details.

Parameters
  • dataset (Dataset) – dataset from which to load the data.

  • batch_size (int, optional) – how many samples per batch to load (default: 1).

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

  • sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.

  • batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

  • collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

  • worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

  • multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)

  • generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)

  • prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If num_workers=0, the default is None; otherwise, if num_workers > 0, the default is 2.)

  • persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)

  • pin_memory_device (str, optional) – Deprecated, the current accelerator will be used as the device if pin_memory=True.

  • in_order (bool, optional) – If False, the data loader will not enforce that batches are returned in a first-in, first-out order. Only applies when num_workers > 0. (default: True)

Warning

If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See Multiprocessing best practices for more details related to multiprocessing in PyTorch.

Warning

The len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of multi-process loading configurations. This represents the best guess PyTorch can make because PyTorch trusts user dataset code in correctly handling multi-process loading to avoid duplicate data.

However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch can not detect such cases in general.

See Dataset Types for more details on these two types of datasets and how IterableDataset interacts with Multi-process data loading.

Warning

Setting in_order to False can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.

class torch.utils.data.Dataset[source]#

An abstract class representing a Dataset.

All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader. Subclasses could also optionally implement __getitems__() to speed up batched sample loading. This method accepts a list of sample indices for a batch and returns the list of samples.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
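
A sketch of what such a pairing could look like for a dataset keyed by strings (both classes below are illustrative):

import torch
from torch.utils.data import DataLoader, Dataset, Sampler

class KeyedDataset(Dataset):
    # Illustrative map-style dataset indexed by string keys instead of integers.
    def __init__(self):
        self.data = {"a": torch.tensor(1.0), "b": torch.tensor(2.0), "c": torch.tensor(3.0)}

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.data)

class KeySampler(Sampler):
    # Yields the string keys the DataLoader should fetch.
    def __init__(self, keys):
        self.keys = keys

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)

dataset = KeyedDataset()
loader = DataLoader(dataset, sampler=KeySampler(["b", "a", "c"]), batch_size=2)
print([batch for batch in loader])  # [tensor([2., 1.]), tensor([3.])]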

class torch.utils.data.IterableDataset[source]#

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. Such form of datasets is particularly useful when data come from a stream.

All subclasses should overwrite __iter__(), which would return an iterator of samples in this dataset.

When a subclass is used with DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.

Example 1: splitting workload across all workers in __iter__()

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]

Example 2: splitting workload across all workers using worker_init_fn

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]
class torch.utils.data.TensorDataset(*tensors)[source]#

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

Parameters

*tensors (Tensor) – tensors that have the same size of the first dimension.
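
A brief usage sketch (the tensors are arbitrary):

import torch
from torch.utils.data import TensorDataset

features = torch.randn(4, 2)
labels = torch.tensor([0, 1, 0, 1])

ds = TensorDataset(features, labels)
sample = ds[2]          # (features[2], labels[2])
print(sample[0].shape)  # torch.Size([2])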

class torch.utils.data.StackDataset(*args, **kwargs)[source]#

Dataset as a stacking of multiple datasets.

This class is useful to assemble different parts of complex input data, given as datasets.

Example

>>> images = ImageDataset()
>>> texts = TextDataset()
>>> tuple_stack = StackDataset(images, texts)
>>> tuple_stack[0] == (images[0], texts[0])
>>> dict_stack = StackDataset(image=images, text=texts)
>>> dict_stack[0] == {"image": images[0], "text": texts[0]}
Parameters
  • *args (Dataset) – Datasets for stacking returned as tuple.

  • **kwargs (Dataset) – Datasets for stacking returned as dict.

class torch.utils.data.ConcatDataset(datasets)[source]#

Dataset as a concatenation of multiple datasets.

This class is useful to assemble different existing datasets.

Parameters

datasets (sequence) – List of datasets to be concatenated
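
A brief usage sketch (the two tensor datasets are arbitrary):

import torch
from torch.utils.data import ConcatDataset, TensorDataset

ds_a = TensorDataset(torch.zeros(3, 2))
ds_b = TensorDataset(torch.ones(2, 2))

combined = ConcatDataset([ds_a, ds_b])
print(len(combined))  # 5
print(combined[3])    # first sample of ds_b: (tensor([1., 1.]),)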

class torch.utils.data.ChainDataset(datasets)[source]#

Dataset for chaining multiple IterableDatasets.

This class is useful to assemble different existing dataset streams. The chaining operation is done on-the-fly, so concatenating large-scale datasets with this class will be efficient.

Parameters

datasets (iterable of IterableDataset) – datasets to be chained together

class torch.utils.data.Subset(dataset, indices)[source]#

Subset of a dataset at specified indices.

Parameters
  • dataset (Dataset) – The whole Dataset

  • indices (sequence) – Indices in the whole set selected for subset
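
A brief usage sketch:

import torch
from torch.utils.data import Subset, TensorDataset

ds = TensorDataset(torch.arange(10))
evens = Subset(ds, indices=[0, 2, 4, 6, 8])

print(len(evens))  # 5
print(evens[1])    # ds[2], i.e. (tensor(2),)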

torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None)[source]#

General collate function that handles collection type of element within each batch.

The function also opens a function registry to deal with specific element types. default_collate_fn_map provides default collate functions for tensors, numpy arrays, numbers and strings.

Parameters
  • batch – a single batch to be collated

  • collate_fn_map (Optional[dict[Union[type, tuple[type, ...]], Callable]]) – Optional dictionary mapping from element type to the corresponding collate function. If the element type isn't present in this dictionary, this function will go through each key of the dictionary in the insertion order to invoke the corresponding collate function if the element type is a subclass of the key.

Example

>>> def collate_tensor_fn(batch, *, collate_fn_map):
...     # Extend this function to handle batch of tensors
...     return torch.stack(batch, 0)
>>> def custom_collate(batch):
...     collate_map = {torch.Tensor: collate_tensor_fn}
...     return collate(batch, collate_fn_map=collate_map)
>>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
>>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

Note

Each collate function requires a positional argument for batch and a keyword argument for the dictionary of collate functions as collate_fn_map.

torch.utils.data.default_collate(batch)[source]#

Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.

Here is the general input type (based on the type of the element within the batch) to output type mapping:

  • torch.Tensor -> torch.Tensor (with an added outer dimension batch size)

  • NumPy Arrays -> torch.Tensor

  • float -> torch.Tensor

  • int -> torch.Tensor

  • str -> str (unchanged)

  • bytes -> bytes (unchanged)

  • Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]

  • NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

  • Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]

Parameters

batch – a single batch to be collated

Example

>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Map` inside the batch:
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle specific type
>>> # Option 1: Write custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: In-place modify `default_collate_fn_map`
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
>>> default_collate(batch)  # Handle `CustomType` automatically
torch.utils.data.default_convert(data)[source]#

Convert each NumPy array element into a torch.Tensor.

If the input is a Sequence, Collection, or Mapping, it tries to convert each element inside to a torch.Tensor. If the input is not a NumPy array, it is left unchanged. This is used as the default function for collation when both batch_sampler and batch_size are NOT defined in DataLoader.

The general input type to output type mapping is similar to that of default_collate(). See the description there for more details.

Parameters

data – a single data point to be converted

Example

>>> # Example with `int`
>>> default_convert(0)
0
>>> # Example with NumPy array
>>> default_convert(np.array([0, 1]))
tensor([0, 1])
>>> # Example with NamedTuple
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_convert(Point(0, 0))
Point(x=0, y=0)
>>> default_convert(Point(np.array(0), np.array(0)))
Point(x=tensor(0), y=tensor(0))
>>> # Example with List
>>> default_convert([np.array([0, 1]), np.array([2, 3])])
[tensor([0, 1]), tensor([2, 3])]
torch.utils.data.get_worker_info()[source]#

Returns the information about the current DataLoader iterator worker process.

When called in a worker, this returns an object guaranteed to have the following attributes:

  • id: the current worker id.

  • num_workers: the total number of workers.

  • seed: the random seed set for the current worker. This value is determined by the main process RNG and the worker id. See DataLoader's documentation for more details.

  • dataset: the copy of the dataset object in this process. Note that this will be a different object in a different process than the one in the main process.

When called in the main process, this returns None.

Note

When used in a worker_init_fn passed over to DataLoader, this method can be useful to set up each worker process differently, for instance, using worker_id to configure the dataset object to only read a specific fraction of a sharded dataset, or use seed to seed other libraries used in dataset code.

Return type

Optional[WorkerInfo]

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)[source]#

Randomly split a dataset into non-overlapping new datasets of given lengths.

If a list of fractions that sum up to 1 is given, the lengths will be computed automatically as floor(frac * len(dataset)) for each fraction provided.

After computing the lengths, if there are any remainders, 1 count will be distributed in round-robin fashion to the lengths until there are no remainders left.

Optionally fix the generator for reproducible results, e.g.:

Example

>>> generator1 = torch.Generator().manual_seed(42)
>>> generator2 = torch.Generator().manual_seed(42)
>>> random_split(range(10), [3, 7], generator=generator1)
>>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)
Parameters
  • dataset (Dataset) – Dataset to be split

  • lengths (sequence) – lengths or fractions of splits to be produced

  • generator (Generator) – Generator used for the random permutation.

Return type

list[torch.utils.data.dataset.Subset[~_T]]

class torch.utils.data.Sampler(data_source=None)[source]#

Base class for all Samplers.

Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices or lists of indices (batches) of dataset elements, and may provide a __len__() method that returns the length of the returned iterators.

Parameters

data_source (Dataset) – This argument is not used and will be removed in 2.2.0. You may still have custom implementation that utilizes it.

Example

>>> class AccedingSequenceLengthSampler(Sampler[int]):
>>>     def __init__(self, data: List[str]) -> None:
>>>         self.data = data
>>>
>>>     def __len__(self) -> int:
>>>         return len(self.data)
>>>
>>>     def __iter__(self) -> Iterator[int]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         yield from torch.argsort(sizes).tolist()
>>>
>>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]):
>>>     def __init__(self, data: List[str], batch_size: int) -> None:
>>>         self.data = data
>>>         self.batch_size = batch_size
>>>
>>>     def __len__(self) -> int:
>>>         return (len(self.data) + self.batch_size - 1) // self.batch_size
>>>
>>>     def __iter__(self) -> Iterator[List[int]]:
>>>         sizes = torch.tensor([len(x) for x in self.data])
>>>         for batch in torch.chunk(torch.argsort(sizes), len(self)):
>>>             yield batch.tolist()

Note

The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader. __len__() 方法不是 DataLoader 嚴格必需的,但在涉及 DataLoader 長度的任何計算中都期望存在。

class torch.utils.data.SequentialSampler(data_source)[source]#

Samples elements sequentially, always in the same order.

Parameters

data_source (Dataset) – dataset to sample from

class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None)[source]#

Samples elements randomly. If without replacement, then sample from a shuffled dataset.

If with replacement, then the user can specify num_samples to draw.

Parameters
  • data_source (Dataset) – dataset to sample from

  • replacement (bool) – samples are drawn on-demand with replacement if True (default: False)

  • num_samples (int) – number of samples to draw (default: len(dataset))

  • generator (Generator) – Generator used in sampling.

class torch.utils.data.SubsetRandomSampler(indices, generator=None)[source]#

Samples elements randomly from a given list of indices, without replacement.

Parameters
  • indices (sequence) – a sequence of indices

  • generator (Generator) – Generator used in sampling.

class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)[source]#

Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Parameters
  • weights (sequence) – a sequence of weights, not necessarily summing up to one

  • num_samples (int) – number of samples to draw

  • replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

  • generator (Generator) – Generator used in sampling.

Example

>>> list(
...     WeightedRandomSampler(
...         [0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True
...     )
... )
[4, 4, 1, 4, 5]
>>> list(
...     WeightedRandomSampler(
...         [0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False
...     )
... )
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)[source]#

Wraps another sampler to yield a mini-batch of indices.

Parameters
  • sampler (Sampler or Iterable) – Base sampler. Can be any iterable object

  • batch_size (int) – Size of mini-batch.

  • drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size

Example

>>> list(
...     BatchSampler(
...         SequentialSampler(range(10)), batch_size=3, drop_last=False
...     )
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(
...     BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)
... )
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False)[source]#

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Note

The dataset is assumed to be of constant size, and any instance of it is assumed to always return the same elements in the same order.

Parameters
  • dataset (Dataset) – Dataset used for sampling.

  • num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.

  • rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.

  • shuffle (bool, optional) – If True (default), the sampler will shuffle the indices.

  • seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.

  • drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.

Warning

In distributed mode, calling the set_epoch() method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.

Example

>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)