• 文件 >
  • 多工環境中的特定任務策略
快捷方式

多工環境中的特定任務策略

本教程詳細介紹瞭如何使用多工策略和批次環境。

在本教程結束時,您將能夠編寫策略,使用不同的權重集在各種環境中計算動作。您還將能夠並行執行各種環境。

from tensordict import LazyStackedTensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.modules import MLP

我們設計了兩個環境:一個擬人化機器人必須完成站立任務,另一個機器人必須學會行走。

env1 = DMControlEnv("humanoid", "stand")
env1_obs_keys = list(env1.observation_spec.keys())
env1 = TransformedEnv(
    env1,
    Compose(
        CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
        CatTensors(env1_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_stand", "observation"],
            in_keys_inv=["action"],
        ),
    ),
)
env2 = DMControlEnv("humanoid", "walk")
env2_obs_keys = list(env2.observation_spec.keys())
env2 = TransformedEnv(
    env2,
    Compose(
        CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
        CatTensors(env2_obs_keys, "observation"),
        DoubleToFloat(
            in_keys=["observation_walk", "observation"],
            in_keys_inv=["action"],
        ),
    ),
)
tdreset1 = env1.reset()
tdreset2 = env2.reset()

# With LazyStackedTensorDict, stacking is done in a lazy manner: the original tensordicts
# can still be recovered by indexing the main tensordict
tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0)
assert tdreset[0] is tdreset1
print(tdreset[0])

策略 (Policy)

我們將設計一個策略,其中一個主幹讀取“observation”鍵。然後,特定的子元件將讀取堆疊的 tensordicts 中的“observation_stand”和“observation_walk”鍵(如果它們存在),並將它們透過專用子網路傳遞。

action_dim = env1.action_spec.shape[-1]
policy_common = TensorDictModule(
    nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
)
policy_stand = TensorDictModule(
    MLP(67 + 64, action_dim, depth=2),
    in_keys=["observation_stand", "hidden"],
    out_keys=["action"],
)
policy_walk = TensorDictModule(
    MLP(67 + 64, action_dim, depth=2),
    in_keys=["observation_walk", "hidden"],
    out_keys=["action"],
)
seq = TensorDictSequential(
    policy_common, policy_stand, policy_walk, partial_tolerant=True
)

讓我們檢查一下我們的序列是否為單個環境(站立)輸出了動作。

seq(env1.reset())

讓我們檢查一下我們的序列是否為單個環境(行走)輸出了動作。

seq(env2.reset())

這也能與堆疊一起工作:現在站立和行走鍵已消失,因為它們不是所有 tensordicts 都共享的。但是 TensorDictSequential 仍然執行了操作。請注意,主幹是以向量化的方式執行的,而不是迴圈執行,這樣更有效率。

seq(tdreset)

並行執行不同的任務

如果通用鍵值對共享相同的規範(特別是它們的形狀和資料型別必須匹配:如果觀察值的形狀不同但由同一個鍵指向,則無法執行以下操作),則可以並行化操作。

如果 ParallelEnv 接收到單個 env 工廠函式,它將假定必須執行單個任務。如果提供函式列表,它將假定我們處於多工設定中。

def env1_maker():
    return TransformedEnv(
        DMControlEnv("humanoid", "stand"),
        Compose(
            CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
            CatTensors(env1_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_stand", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )


def env2_maker():
    return TransformedEnv(
        DMControlEnv("humanoid", "walk"),
        Compose(
            CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
            CatTensors(env2_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_walk", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )


env = ParallelEnv(2, [env1_maker, env2_maker])
assert not env._single_task

tdreset = env.reset()
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # should be different

讓我們將輸出傳遞給我們的網路。

tdreset = seq(tdreset)
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # should be different but all have an "action" key


env.step(tdreset)  # computes actions and execute steps in parallel
print(tdreset)
print(tdreset[0])
print(tdreset[1])  # next_observation has now been written

回滾 (Rollout)

td_rollout = env.rollout(100, policy=seq, return_contiguous=False)
td_rollout[:, 0]  # tensordict of the first step: only the common keys are shown
td_rollout[0]  # tensordict of the first env: the stand obs is present

env.close()
del env

由 Sphinx-Gallery 生成的畫廊

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源