快捷方式

set_list_to_stack

class tensordict.set_list_to_stack(mode: (Python v3.13))

用於控制 TensorDict 中列表處理行為的上下文管理器和裝飾器。

啟用後,分配給 TensorDict 的列表將自動沿著批次維度堆疊。這對於確保列表中的張量或其他元素在 TensorDict 中被視為可堆疊實體非常有用。

當前行為
如果未透過此上下文管理器將列表分配給 TensorDict,它將被轉換為 numpy 陣列

幷包裝在 NonTensorData 中,如果它無法轉換為張量。

引數:

mode ((Python v3.13)bool) – 如果為 True,則啟用列表到堆疊的轉換。如果為 False,則停用它。

示例

>>> with set_list_to_stack(True):
...     td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2)
...     assert (td["a"] == torch.tensor([0, 1])).all()
...     assert td[0]["a"] == 0
...     assert td[1]["a"] == 1

另請參閱

list_to_stack().

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源