MCTSForest

class torchrl.data.MCTSForest(*, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, max_size: int | None = None, done_keys: list[NestedKey] | None = None, reward_keys: list[NestedKey] = None, observation_keys: list[NestedKey] = None, action_keys: list[NestedKey] = None, excluded_keys: list[NestedKey] = None, consolidated: bool | None = None)[source]

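A collection of MCTS trees.

Warning

This class is under active development. Expect frequent API changes.

This class stores rollouts in a storage and builds trees from a given root in that dataset.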

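Keyword Arguments:
  • data_map (TensorDictMap, optional) – The storage used to hold the data (observations, rewards, states, etc.). If not provided, it is lazily initialized with from_tensordict_pair(), using the list of observation_keys and action_keys as in_keys.

  • node_map (TensorDictMap, optional) – A map from the observation space to the index space. Internally, the node_map is used to gather all possible branches coming out of a given node. For example, if an observation has two associated actions and outcomes in the data map, the node_map returns a data structure containing the two indices in data_map that correspond to these two outcomes. If not provided, it is lazily initialized with from_tensordict_pair(), using the list of observation_keys as in_keys and the QueryModule as out_keys.

  • max_size (int, optional) – The size of the maps. If not provided, defaults to data_map.max_size if it can be found, then to node_map.max_size. If neither is available, defaults to 1000.

  • done_keys (list of NestedKey, optional) – The done keys of the environment. If not provided, defaults to ("done", "terminated", "truncated"). get_keys_from_env() can be used to determine the keys automatically.

  • action_keys (list of NestedKey, optional) – The action keys of the environment. If not provided, defaults to ("action",). get_keys_from_env() can be used to determine the keys automatically.

  • reward_keys (list of NestedKey, optional) – The reward keys of the environment. If not provided, defaults to ("reward",). get_keys_from_env() can be used to determine the keys automatically.

  • observation_keys (list of NestedKey, optional) – The observation keys of the environment. If not provided, defaults to ("observation",). get_keys_from_env() can be used to determine the keys automatically.

  • excluded_keys (list of NestedKey, optional) – A list of keys to exclude from the data storage.

  • consolidated (bool, optional) – If True, the data_map storage will be consolidated on disk. Defaults to False.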
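
To make the roles of data_map, node_map and the max_size fallback more concrete, here is a minimal pure-Python sketch. It is a toy model and not TorchRL's implementation: ToyMaps, resolve_max_size and DEFAULT_MAX_SIZE are hypothetical names introduced for illustration, and plain lists and dicts stand in for TensorDictMap.

```python
from collections import defaultdict
from types import SimpleNamespace

DEFAULT_MAX_SIZE = 1000  # documented fallback when no size can be found


def resolve_max_size(max_size=None, data_map=None, node_map=None):
    """Hypothetical helper mirroring the documented fallback order for max_size."""
    candidates = (
        max_size,
        getattr(data_map, "max_size", None),
        getattr(node_map, "max_size", None),
    )
    for candidate in candidates:
        if candidate is not None:
            return candidate
    return DEFAULT_MAX_SIZE


class ToyMaps:
    """Toy stand-in: a data store plus an observation -> indices lookup."""

    def __init__(self):
        self.data = []                     # plays the role of data_map
        self.slot_of = {}                  # (observation, action) -> index in data
        self.node_map = defaultdict(list)  # observation -> indices of its branches

    def add(self, observation, action, next_observation):
        key = (observation, action)
        if key in self.slot_of:            # same key values reuse the same slot
            return self.slot_of[key]
        index = len(self.data)
        self.data.append(
            {"observation": observation, "action": action, "next": next_observation}
        )
        self.slot_of[key] = index
        self.node_map[observation].append(index)
        return index

    def branches(self, observation):
        """Return every stored outcome that starts at this observation."""
        return [self.data[i] for i in self.node_map.get(observation, [])]


maps = ToyMaps()
maps.add(0, 1, 123)
maps.add(0, 2, 359)
maps.add(123, 9, 392)
print(maps.node_map[0])                         # indices of the branches leaving observation 0
print([b["action"] for b in maps.branches(0)])  # the actions of those branches
print(resolve_max_size(data_map=SimpleNamespace(max_size=50)))
print(resolve_max_size())
```

In this toy, the data store is keyed by the (observation, action) pair while the node lookup is keyed by the observation alone, which is the split the parameter descriptions above suggest.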

Examples

>>> from torchrl.envs import GymEnv
>>> import torch
>>> from tensordict import TensorDict, LazyStackedTensorDict
>>> from torchrl.data import TensorDictMap, ListStorage
>>> from torchrl.data.map.tree import MCTSForest
>>>
>>> from torchrl.envs import PendulumEnv, CatTensors, UnsqueezeTransform, StepCounter
>>> # Create the MCTS Forest
>>> forest = MCTSForest()
>>> # Create an environment. We're using a stateless env to be able to query it at any given state (like an oracle)
>>> env = PendulumEnv()
>>> obs_keys = list(env.observation_spec.keys(True, True))
>>> state_keys = set(env.full_state_spec.keys(True, True)) - set(obs_keys)
>>> # Appending transforms to get an "observation" key that concatenates the observations together
>>> env = env.append_transform(
...     UnsqueezeTransform(
...         in_keys=obs_keys,
...         out_keys=[("unsqueeze", key) for key in obs_keys],
...         dim=-1
...     )
... )
>>> env = env.append_transform(
...     CatTensors([("unsqueeze", key) for key in obs_keys], "observation")
... )
>>> env = env.append_transform(StepCounter())
>>> env.set_seed(0)
>>> # Get a reset state, then make a rollout out of it
>>> reset_state = env.reset()
>>> rollout0 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> # Append the rollout to the forest. We're removing the state entries for clarity
>>> rollout0 = rollout0.copy()
>>> rollout0.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout0)
>>> # The forest should have 6 elements (the length of the rollout)
>>> assert len(forest) == 6
>>> # Let's make another rollout from the same reset state
>>> rollout1 = env.rollout(6, auto_reset=False, tensordict=reset_state.clone())
>>> rollout1.exclude(*state_keys, inplace=True).get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1)
>>> assert len(forest) == 12
>>> # Let's make one final rollout from an intermediate step of the second rollout
>>> rollout1b = env.rollout(6, auto_reset=False, tensordict=rollout1[3].exclude("next"))
>>> rollout1b.exclude(*state_keys, inplace=True)
>>> rollout1b.get("next").exclude(*state_keys, inplace=True)
>>> forest.extend(rollout1b)
>>> assert len(forest) == 18
>>> # Since we have 2 rollouts starting at the same state, our tree should have two
>>> #  branches if we produce it from the reset entry. Take the state, and call `get_tree`:
>>> r = rollout0[0]
>>> # Let's get the compact tree that follows the initial reset. A compact tree is
>>> #  a tree where nodes that have a single child are collapsed.
>>> tree = forest.get_tree(r)
>>> print(tree.max_length())
2
>>> print(list(tree.valid_paths()))
[(0,), (1, 0), (1, 1)]
>>> from tensordict import assert_close
>>> # We can manually rebuild the tree
>>> assert_close(
...     rollout1,
...     torch.cat([tree.subtree[1].rollout, tree.subtree[1].subtree[0].rollout]),
...     intersection=True,
... )
True
>>> # Or we can rebuild it using the dedicated method
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0)),
...     intersection=True,
... )
True
>>> tree.plot()
>>> tree = forest.get_tree(r, compact=False)
>>> print(tree.max_length())
9
>>> print(list(tree.valid_paths()))
[(0, 0, 0, 0, 0, 0), (1, 0, 0, 0, 0, 0), (1, 0, 0, 1, 0, 0, 0, 0, 0)]
>>> assert_close(
...     rollout1,
...     tree.rollout_from_path((1, 0, 0, 0, 0, 0)),
...     intersection=True,
... )
True
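
The difference between the compact and non-compact trees in the example above can be reproduced with a small pure-Python trie. This is only a sketch under stated assumptions: Node, insert, compact and valid_paths are hypothetical helpers (not TorchRL's Tree class), and the action labels x0, y0, z0 and so on are placeholders for the three rollouts (rollout0, rollout1, and rollout1b branching off rollout1 after three steps).

```python
class Node:
    """Toy node: `steps` holds the steps on the edge leading into this node."""

    def __init__(self, steps=()):
        self.steps = list(steps)
        self.children = []


def insert(root, actions):
    """Add one rollout to a non-compact trie (one node per step)."""
    current = root
    for action in actions:
        for child in current.children:
            if child.steps == [action]:
                current = child
                break
        else:
            child = Node([action])
            current.children.append(child)
            current = child


def compact(node):
    """Return a copy in which every single-child chain is merged into one node."""
    merged = Node(node.steps)
    current = node
    while len(current.children) == 1:
        current = current.children[0]
        merged.steps.extend(current.steps)
    merged.children = [compact(child) for child in current.children]
    return merged


def valid_paths(node, prefix=()):
    """Yield the child-index path from the root to every leaf."""
    if not node.children:
        yield prefix
        return
    for i, child in enumerate(node.children):
        yield from valid_paths(child, prefix + (i,))


rollout0 = [f"x{i}" for i in range(6)]
rollout1 = [f"y{i}" for i in range(6)]
rollout1b = rollout1[:3] + [f"z{i}" for i in range(6)]  # branches after 3 steps

full_tree = Node()
for rollout in (rollout0, rollout1, rollout1b):
    insert(full_tree, rollout)

compact_tree = compact(full_tree)
print(list(valid_paths(full_tree)))
print(list(valid_paths(compact_tree)))
print(max(len(path) for path in valid_paths(compact_tree)))
```

With these three rollouts the toy yields the same path tuples as the valid_paths() outputs printed above, and the longest path has length 9 in the full trie and 2 in the compact one.
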
property action_keys: list[tensordict._nestedkey.NestedKey]

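Action keys.

Returns the keys used to retrieve actions from the environment input. The default action key is "action".

Returns:

A list of strings or tuples representing the action keys.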

property done_keys: list[tensordict._nestedkey.NestedKey]

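Done keys.

Returns the keys used to indicate that an episode has ended. The default done keys are "done", "terminated" and "truncated". These keys can be used in the environment's output to signal the end of an episode.

Returns:

A list of strings representing the done keys.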

extend(rollout, *, return_node: bool = False)[source]

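Adds a rollout to the forest.

Nodes are added to a tree only at the points where rollouts diverge from one another and at the end points of rollouts.

If no existing tree matches the first steps of the rollout, a new tree is added. Only one node is created, for the final step.

If an existing tree matches the rollout, the rollout is added to that tree. If the rollout diverges from all other rollouts in the tree at some step, a new node is created before the step where the rollouts diverge, and a leaf node is created for the final step of the rollout. If all of the rollout's steps match a previously added rollout, nothing changes. If the rollout matches up to a leaf node of a tree but continues beyond it, that node is extended to the end of the rollout and no new nodes are created.

Parameters:
  • rollout (TensorDict) – The rollout to add to the forest.

  • return_node (bool, optional) – If True, the method returns the added node. Defaults to False.

Returns:

The node that was added to the forest. This is only returned if return_node is True.

Return type:

Tree

Examples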

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> import torch
>>> forest = MCTSForest()
>>> r0 = TensorDict({
...     'action': torch.tensor([1, 2, 3, 4, 5]),
...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
...     'observation': torch.tensor([  0, 123, 392, 989, 809])
... }, [5])
>>> r1 = TensorDict({
...     'action': torch.tensor([1, 2, 6, 7]),
...     'next': {'observation': torch.tensor([123, 392, 235,  38])},
...     'observation': torch.tensor([  0, 123, 392, 235])
... }, [4])
>>> td_root = r0[0].exclude("next")
>>> forest.extend(r0)
>>> forest.extend(r1)
>>> tree = forest.get_tree(td_root)
>>> print(tree)
Tree(
    count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
    index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
    node_data=TensorDict(
        fields={
            observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([]),
        device=cpu,
        is_shared=False),
    node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
    rollout=TensorDict(
        fields={
            action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
            next: TensorDict(
                fields={
                    observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
                batch_size=torch.Size([2]),
                device=cpu,
                is_shared=False),
            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
        batch_size=torch.Size([2]),
        device=cpu,
        is_shared=False),
    subtree=Tree(
        _parent=NonTensorStack(
            [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
            batch_size=torch.Size([2]),
            device=None),
        count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
        hash=NonTensorStack(
            [4341220243998689835, 6745467818783115365],
            batch_size=torch.Size([2]),
            device=None),
        node_data=LazyStackedTensorDict(
            fields={
                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        node_id=NonTensorStack(
            [1, 2],
            batch_size=torch.Size([2]),
            device=None),
        rollout=LazyStackedTensorDict(
            fields={
                action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
                next: LazyStackedTensorDict(
                    fields={
                        observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
                    exclusive_fields={
                    },
                    batch_size=torch.Size([2, -1]),
                    device=cpu,
                    is_shared=False,
                    stack_dim=0),
                observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
            exclusive_fields={
            },
            batch_size=torch.Size([2, -1]),
            device=cpu,
            is_shared=False,
            stack_dim=0),
        wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
        index=None,
        subtree=None,
        specs=None,
        batch_size=torch.Size([2]),
        device=None,
        is_shared=False),
    wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
    hash=None,
    _parent=None,
    specs=None,
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
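
The node-creation rules described for extend() resemble insertion into a radix tree. The sketch below is a simplified pure-Python model of those rules, not TorchRL's code: Node, common_prefix_len and this extend function are hypothetical, each step is an (action, next observation) pair taken from r0 and r1 above, and the extra step (8, 99) is made up to show a leaf being extended.

```python
class Node:
    """Toy tree node: `rollout` holds the steps on the edge leading into it."""

    def __init__(self, rollout=()):
        self.rollout = list(rollout)
        self.children = []


def common_prefix_len(a, b):
    n = 0
    while n < len(a) and n < len(b) and a[n] == b[n]:
        n += 1
    return n


def extend(node, steps):
    """Add `steps` below `node`; return how many nodes were created."""
    steps = list(steps)
    if not steps:
        return 0
    for idx, child in enumerate(node.children):
        n = common_prefix_len(child.rollout, steps)
        if n == 0:
            continue  # this branch starts with a different step
        if n == len(steps):
            return 0  # every step is already stored: nothing changes
        if n == len(child.rollout):
            rest = steps[n:]
            if not child.children:
                child.rollout.extend(rest)  # leaf reached: grow it, no new node
                return 0
            return extend(child, rest)  # keep matching further down
        # Divergence inside the edge: split it just before the diverging step
        # and add a leaf for the end of the new rollout.
        fork = Node(child.rollout[:n])
        child.rollout = child.rollout[n:]
        fork.children = [child, Node(steps[n:])]
        node.children[idx] = fork
        return 2
    node.children.append(Node(steps))  # no match: one node for the final step
    return 1


def count_nodes(node):
    return 1 + sum(count_nodes(child) for child in node.children)


r0 = list(zip([1, 2, 3, 4, 5], [123, 392, 989, 809, 847]))
r1 = list(zip([1, 2, 6, 7], [123, 392, 235, 38]))

root = Node()
created = [
    extend(root, r0),              # new branch: one node
    extend(root, r1),              # diverges after two steps: fork plus leaf
    extend(root, r1),              # already stored: nothing changes
    extend(root, r1 + [(8, 99)]),  # runs past a leaf: the leaf is extended
]
print(created)
print(count_nodes(root))
```

One difference from the printed Tree above is that this toy keeps an empty root and stores the shared two-step prefix in a child node, whereas the printed root carries a rollout of length 2.
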
get_keys_from_env(env: EnvBase)[source]

Writes the missing done, action and reward keys to the Forest, given an environment.

Existing keys are not overwritten.

property observation_keys: list[tensordict._nestedkey.NestedKey]

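Observation keys.

Returns the keys used to retrieve observations from the environment output. The default observation key is "observation".

Returns:

A list of strings or tuples representing the observation keys.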

property reward_keys: list[tensordict._nestedkey.NestedKey]

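Reward keys.

Returns the keys used to retrieve rewards from the environment output. The default reward key is "reward".

Returns:

A list of strings or tuples representing the reward keys.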

to_string(td_root, node_format_fn=<function MCTSForest.<lambda>>)[source]

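Generates a string representation of a tree in the forest.

This function can pull information from each node of a tree, which makes it useful for debugging. Nodes are listed one per line. Each line contains the path to the node, followed by the string representation of that node generated with node_format_fn. Each line is indented according to the number of path steps needed to reach the corresponding node.

Parameters:
  • td_root (TensorDict) – The root of the tree.

  • node_format_fn (Callable, optional) – A user-defined function that generates a string for each node of the tree. Its signature must be (Tree) -> Any, and the output must be convertible to a string. If this argument is not given, the generated string is the node's Tree.node_data attribute converted to a dict.

Examples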

>>> from torchrl.data import MCTSForest
>>> from tensordict import TensorDict
>>> forest = MCTSForest()
>>> td_root = TensorDict({"observation": 0,})
>>> rollouts_data = [
...     # [(action, obs), ...]
...     [(3, 123), (1, 456)],
...     [(2, 359), (2, 3094)],
...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
...     [(1, 75)],
...     [(3, 123), (0, 948)],
...     [(2, 359), (2, 3094), (10, 68)],
...     [(2, 359), (2, 3094), (11, 9045)],
... ]
>>> for rollout_data in rollouts_data:
...     td = td_root.clone().unsqueeze(0)
...     for action, obs in rollout_data:
...         td = td.update(TensorDict({
...             "action": [action],
...             "next": TensorDict({"observation": [obs]}, [1]),
...         }, [1]))
...         forest.extend(td)
...         td = td["next"].clone()
...
>>> print(forest.to_string(td_root))
(0,) {'observation': tensor(123)}
(0, 0) {'observation': tensor(456)}
(0, 1) {'observation': tensor(847)}
(0, 2) {'observation': tensor(948)}
(1,) {'observation': tensor(3094)}
(1, 0) {'observation': tensor(68)}
(1, 1) {'observation': tensor(9045)}
(2,) {'observation': tensor(75)}
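
The line format described for to_string() (indentation, then the path, then the formatted node) can be sketched in plain Python. This is an illustrative model only: the to_string function and make_node helper below are hypothetical, nodes are plain dicts rather than Tree objects, and the indentation width is an assumption exposed as the indent parameter.

```python
def to_string(root, node_format_fn=lambda node: node["data"], indent="  "):
    """Render one line per node: indentation, path, then the formatted node."""
    lines = []

    def visit(node, path):
        for i, child in enumerate(node["children"]):
            child_path = path + (i,)
            pad = indent * (len(child_path) - 1)  # deeper nodes get more padding
            lines.append(f"{pad}{child_path} {node_format_fn(child)}")
            visit(child, child_path)

    visit(root, ())
    return "\n".join(lines)


def make_node(observation, *children):
    """Hypothetical helper that builds a toy node as a plain dict."""
    return {"data": {"observation": observation}, "children": list(children)}


# Same shape as the tree printed in the example above.
toy_tree = make_node(
    0,
    make_node(123, make_node(456), make_node(847), make_node(948)),
    make_node(3094, make_node(68), make_node(9045)),
    make_node(75),
)
print(to_string(toy_tree))
# A custom formatter that prints only the observation value of each node.
print(to_string(toy_tree, node_format_fn=lambda n: n["data"]["observation"]))
```

Passing a custom node_format_fn is the main lever here: any callable that takes a node and returns something convertible to a string can replace the default dict dump.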
