get_primers_from_module
- class torchrl.modules.utils.get_primers_from_module(module)[source]
Get all tensordict primers from all submodules of a module.
This method is useful for retrieving primers from modules that are contained within a parent module.
- Parameters:
module (torch.nn.Module) – The parent module.
- Returns:
A TensorDictPrimer Transform.
- Return type:
TensorDictPrimer
Examples
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
...     input_size=10,
...     hidden_size=10,
...     num_layers=1,
...     in_keys=["input", "recurrent_state", "is_init"],
...     out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
...     MLP(
...         in_features=10,
...         out_features=10,
...         num_cells=[],
...     ),
...     in_keys=["features"],
...     out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=Composite(
    recurrent_state: UnboundedContinuous(
        shape=torch.Size([1, 10]), space=None, device=cpu, dtype=torch.float32, domain=continuous),
    device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
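The returned transform is typically appended to an environment so that the recurrent state is created at reset. The following is a minimal sketch, not part of the original example: it assumes a gym/gymnasium backend is installed, and the "Pendulum-v1" task name is purely illustrative.
>>> # Hypothetical usage: attach the primer to an environment so the
>>> # "recurrent_state" entry is initialized when the env is reset.
>>> from torchrl.envs import GymEnv, TransformedEnv
>>> env = TransformedEnv(GymEnv("Pendulum-v1"))  # env choice is illustrative
>>> env.append_transform(primers)
>>> td = env.reset()
>>> assert "recurrent_state" in td.keys()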