is_tensor_collection¶
- class tensordict.is_tensor_collection(datatype: Union[type, Any])¶
檢查一個數據物件或型別是否來自 tensordict 庫的張量容器。
- 返回:
如果輸入是 TensorDictBase 的子類、tensorclass 或這些的例項,則返回
True。否則返回False。
示例
>>> is_tensor_collection(TensorDictBase) # True >>> is_tensor_collection(TensorDict()) # True >>> @tensorclass ... class MyClass: ... pass ... >>> is_tensor_collection(MyClass) # True >>> is_tensor_collection(MyClass(batch_size=[])) # True