OpenMLEnv
- torchrl.envs.OpenMLEnv(*args, **kwargs)
An environment interface to OpenML datasets, for use in bandit settings.
Documentation: https://www.openml.org/search?type=data
Scikit-learn interface: https://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_openml.html
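- Parameters:
dataset_name (str) – the dataset to load. Supported datasets are "adult_num", "adult_onehot", "mushroom_num", "mushroom_onehot", "covertype", "shuttle" and "magic".
device (torch.device or compatible type, optional) – the device on which input and output data are expected. Defaults to "cpu".
batch_size (torch.Size or compatible type, optional) – the batch size of the environment, i.e. the number of elements sampled when reset() is called. Defaults to an empty batch size, i.e. one element is sampled at a time.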
- Variables:
available_envs (List[str]) – the list of environments that can be built by this class.
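To make the `batch_size` semantics concrete, here is a minimal pure-Python sketch (not TorchRL code) of how a batch size maps to the number of rows drawn per `reset()` call and to the resulting observation shape. The helper names `num_samples` and `observation_shape` are hypothetical, and the feature count of 106 is taken from the `adult_onehot` example on this page.

```python
import math
from typing import Sequence, Tuple


def num_samples(batch_size: Sequence[int] = ()) -> int:
    """Number of dataset rows drawn per reset for a given batch size.

    An empty batch size yields a single row, since the product of an
    empty sequence is 1.
    """
    return math.prod(batch_size)


def observation_shape(batch_size: Sequence[int], n_features: int) -> Tuple[int, ...]:
    """Observation shape: the batch dimensions followed by the feature dimension."""
    return (*batch_size, n_features)


print(num_samples([2, 3]))             # 6
print(observation_shape([2, 3], 106))  # (2, 3, 106)
print(num_samples(()))                 # 1
print(observation_shape((), 106))      # (106,)
```

With `batch_size=[2, 3]`, six rows are sampled and arranged into a 2 x 3 grid, which agrees with the `observation` shape of `[2, 3, 106]` in the example output.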
Examples
>>> env = OpenMLEnv("adult_onehot", batch_size=[2, 3])
>>> print(env.reset())
TensorDict(
    fields={
        done: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([2, 3, 106]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([2, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2, 3]),
    device=cpu,
    is_shared=False)
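The fields printed above hint at how a contextual-bandit round could be scored: a policy sees `observation`, picks an action, and receives a `reward` that depends on the label `y`. The sketch below is a pure-Python toy and not TorchRL's implementation. The reward rule (1.0 when the action equals the label, 0.0 otherwise), the `TOY_ROWS` data, and the `reset`, `step` and `policy` helpers are all assumptions made for illustration.

```python
import random
from typing import List, Tuple

# Hypothetical toy dataset: (features, label) rows standing in for an OpenML table.
TOY_ROWS: List[Tuple[List[float], int]] = [
    ([0.0, 1.0], 0),
    ([1.0, 0.0], 1),
    ([1.0, 1.0], 1),
    ([0.0, 0.0], 0),
]


def reset(n: int, rng: random.Random) -> Tuple[List[List[float]], List[int]]:
    """Sample n rows with replacement and return (observations, labels)."""
    rows = [rng.choice(TOY_ROWS) for _ in range(n)]
    observations = [features for features, _ in rows]
    labels = [label for _, label in rows]
    return observations, labels


def step(actions: List[int], labels: List[int]) -> List[float]:
    """Assumed reward rule: 1.0 if the action matches the label, else 0.0."""
    return [1.0 if a == y else 0.0 for a, y in zip(actions, labels)]


def policy(observation: List[float]) -> int:
    """Hypothetical policy: use the first feature as the predicted class."""
    return int(observation[0])


rng = random.Random(0)
observations, labels = reset(6, rng)  # six rows, as with batch_size=[2, 3]
actions = [policy(obs) for obs in observations]
rewards = step(actions, labels)
print(sum(rewards) / len(rewards))  # mean reward over the sampled batch
```

In this toy data the label always equals the first feature, so the hypothetical policy is rewarded on every row; a real policy would have to learn the mapping from observations to labels.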