from_pytree¶
- class tensordict.from_pytree(pytree, *, batch_size: Optional[Size] = None, auto_batch_size: bool = False, batch_dims: Optional[int] = None)¶
將 pytree 轉換為 TensorDict 例項。
此方法旨在儘可能保留 pytree 的巢狀結構。
其他非張量鍵將被新增,以跟蹤每個級別的標識,從而提供內建的 pytree 到 tensordict 的雙射轉換 API。
當前接受的類包括列表、元組、命名元組和字典。
注意
對於字典,非 `NestedKey` 鍵被單獨註冊為 `NonTensorData` 例項。
注意
可轉換為張量型別(如 int、float 或 np.ndarray)將被轉換為 torch.Tensor 例項。請注意,此轉換是滿射的:將 tensordict 轉換回 pytree 將無法恢復原始型別。
示例
>>> # Create a pytree with tensor leaves, and one "weird"-looking dict key >>> class WeirdLookingClass: ... pass ... >>> weird_key = WeirdLookingClass() >>> # Make a pytree with tuple, lists, dict and namedtuple >>> pytree = ( ... [torch.randint(10, (3,)), torch.zeros(2)], ... { ... "tensor": torch.randn( ... 2, ... ), ... "td": TensorDict({"one": 1}), ... weird_key: torch.randint(10, (2,)), ... "list": [1, 2, 3], ... }, ... {"named_tuple": TensorDict({"two": torch.ones(1) * 2}).to_namedtuple()}, ... ) >>> # Build a TensorDict from that pytree >>> td = from_pytree(pytree) >>> # Recover the pytree >>> pytree_recon = td.to_pytree() >>> # Check that the leaves match >>> def check(v1, v2): >>> assert (v1 == v2).all() >>> >>> torch.utils._pytree.tree_map(check, pytree, pytree_recon) >>> assert weird_key in pytree_recon[1]