make_tensordict
- class tensordict.make_tensordict(input_dict: Optional[dict[str, Union[torch.Tensor, tensordict._tensorcollection.TensorCollection]]] = None, batch_size: Optional[Union[Sequence[int], Size, int]] = None, device: Optional[Union[device, str, int]] = None, auto_batch_size: Optional[bool] = None, **kwargs: Union[Tensor, TensorCollection])
Returns a TensorDict created from the keyword arguments or an input dictionary.
If batch_size is not specified, returns the maximum batch size possible.

This function works on nested dictionaries too, or can be used to determine the batch-size of a nested tensordict.
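To make "maximum batch size possible" concrete, the sketch below computes the longest run of leading dimensions shared by every tensor shape. It is a pure-Python illustration of the rule only: the helper name `max_common_batch_size` is hypothetical and is not part of tensordict, and the library's internal implementation may differ.

```python
from typing import Iterable, Sequence, Tuple


def max_common_batch_size(shapes: Iterable[Sequence[int]]) -> Tuple[int, ...]:
    """Return the longest leading prefix shared by all shapes.

    Hypothetical helper for illustration; not a tensordict function.
    """
    shapes = [tuple(shape) for shape in shapes]
    if not shapes:
        return ()
    prefix = []
    # zip stops at the shortest shape, so the prefix can never be
    # longer than the tensor with the fewest dimensions.
    for dims in zip(*shapes):
        if all(d == dims[0] for d in dims):
            prefix.append(dims[0])
        else:
            break
    return tuple(prefix)


# Shapes taken from the entries "a" and "b" in the examples on this page.
print(max_common_batch_size([(3, 4), (3,)]))
```

Under this rule, entries of shape (3, 4) and (3,) share only the leading dimension 3, which matches the batch_size=torch.Size([3]) reported in the examples on this page.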
- Parameters:
input_dict (dictionary, optional) – a dictionary to use as a data source (nested keys compatible).
**kwargs (TensorDict or torch.Tensor) – keyword arguments as data source (incompatible with nested keys).
batch_size (iterable of int, optional) – a batch size for the tensordict.
device (torch.device or compatible type, optional) – a device for the TensorDict.
auto_batch_size (bool, optional) – if True, the batch size will be computed automatically. Defaults to False.
Examples
>>> import torch
>>> from tensordict import TensorDict, make_tensordict
>>> input_dict = {"a": torch.randn(3, 4), "b": torch.randn(3)}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # alternatively
>>> td = make_tensordict(**input_dict)
>>> # nested dict: the nested TensorDict can have a different batch-size
>>> # as long as its leading dims match.
>>> input_dict = {"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}
>>> print(make_tensordict(input_dict))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)
>>> # we can also use this to work out the batch size of a tensordict
>>> input_td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3, 4)}}, [])
>>> print(make_tensordict(input_td))
TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)