MultiAgentMLP¶
- class torchrl.modules.MultiAgentMLP(n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, *, centralized: bool | None = None, share_params: bool | None = None, device: DEVICE_TYPING | None = None, depth: int | None = None, num_cells: Sequence | int | None = None, activation_class: type[nn.Module] | None = <class 'torch.nn.modules.activation.Tanh'>, use_td_params: bool = True, **kwargs)[來源]¶
多智慧體 MLP。
這是一個可在多智慧體場景中使用的 MLP。例如,作為策略或價值函式。有關示例,請參閱 examples/multiagent。
它期望輸入形狀為 (*B, n_agents, n_agent_inputs)。它返回的輸出形狀為 (*B, n_agents, n_agent_outputs)。
如果 share_params 為 True,則所有智慧體將使用相同的 MLP 來進行前向傳播(同質策略)。否則,每個智慧體將使用不同的 MLP 來處理其輸入(異質策略)。
如果 centralized 為 True,則每個智慧體將使用所有智慧體的輸入來計算其輸出(n_agent_inputs * n_agents 將是單個智慧體的輸入數量)。否則,每個智慧體將僅使用其自身的資料作為輸入。
- 引數:
n_agent_inputs (int 或 None) – 每個智慧體的輸入數量。如果為
None,則輸入數量將在第一次呼叫時延遲例項化。n_agent_outputs (int) – 每個智慧體的輸出數量。
n_agents (int) – 代理數量。
- 關鍵字引數:
centralized (bool) – 如果 centralized 為 True,則每個智慧體將使用所有智慧體的輸入來計算其輸出(n_agent_inputs * n_agents 將是單個智慧體的輸入數量)。否則,每個智慧體將僅使用其自身的資料作為輸入。
share_params (bool) – 如果 share_params 為 True,則所有智慧體將使用相同的 MLP 來進行前向傳播(同質策略)。否則,每個智慧體將使用不同的 MLP 來處理其輸入(異質策略)。
device (str 或 toech.device, optional) – 用於建立模組的裝置。
depth (int, optional) – 網路的深度。深度為 0 將產生一個具有所需輸入和輸出大小的單個線性層網路。長度為 1 將建立 2 個線性層,依此類推。如果未指定深度,則深度資訊應包含在 num_cells 引數中(見下文)。如果 num_cells 是一個可迭代物件且指定了深度,兩者應匹配:len(num_cells) 必須等於 depth。預設值:3。
num_cells (int 或 Sequence[int], optional) – 輸入和輸出之間的每一層的單元數。如果提供一個整數,則每一層將具有相同的單元數。如果提供一個可迭代物件,則線性層的 out_features 將與 num_cells 的內容匹配。預設值:32。
activation_class (Type[nn.Module]) – 要使用的啟用類。預設值:nn.Tanh。
use_td_params (bool, optional) – 如果為
True,則引數可以在 self.params 中找到,它是一個TensorDictParams物件(它同時繼承自 TensorDict 和 nn.Module)。如果為False,則引數包含在 self._empty_net 中。總而言之,這兩種方法應該大致相同但不可互換:例如,當use_td_params=False時,使用use_td_params=True建立的state_dict不能使用。**kwargs – 可以傳遞給
torchrl.modules.models.MLP以自定義 MLP。
注意
要使用 torch.nn.init 模組初始化 MARL 模組引數,請參閱
get_stateful_net()和from_stateful_net()方法。示例
>>> from torchrl.modules import MultiAgentMLP >>> import torch >>> n_agents = 6 >>> n_agent_inputs=3 >>> n_agent_outputs=2 >>> batch = 64 >>> obs = torch.zeros(batch, n_agents, n_agent_inputs) >>> # instantiate a local network shared by all agents (e.g. a parameter-shared policy) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, ... n_agents=n_agents, ... centralized=False, ... share_params=True, ... depth=2, ... ) >>> print(mlp) MultiAgentMLP( (agent_networks): ModuleList( (0): MLP( (0): Linear(in_features=3, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=32, bias=True) (3): Tanh() (4): Linear(in_features=32, out_features=2, bias=True) ) ) ) >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) Now let's instantiate a centralized network shared by all agents (e.g. a centalised value function) >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, ... n_agents=n_agents, ... centralized=True, ... share_params=True, ... depth=2, ... ) >>> print(mlp) MultiAgentMLP( (agent_networks): ModuleList( (0): MLP( (0): Linear(in_features=18, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=32, bias=True) (3): Tanh() (4): Linear(in_features=32, out_features=2, bias=True) ) ) ) We can see that the input to the first layer is n_agents * n_agent_inputs, this is because in the case the net acts as a centralized mlp (like a single huge agent) >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs) Outputs will be identical for all agents. Now we can do both examples just shown but with an independent set of parameters for each agent Let's show the centralized=False case. >>> mlp = MultiAgentMLP( ... n_agent_inputs=n_agent_inputs, ... n_agent_outputs=n_agent_outputs, ... n_agents=n_agents, ... centralized=False, ... share_params=False, ... depth=2, ... ) >>> print(mlp) MultiAgentMLP( (agent_networks): ModuleList( (0-5): 6 x MLP( (0): Linear(in_features=3, out_features=32, bias=True) (1): Tanh() (2): Linear(in_features=32, out_features=32, bias=True) (3): Tanh() (4): Linear(in_features=32, out_features=2, bias=True) ) ) ) We can see that this is the same as in the first example, but now we have 6 MLPs, one per agent! >>> assert mlp(obs).shape == (batch, n_agents, n_agent_outputs)