set_list_to_stack¶
- class tensordict.set_list_to_stack(mode: (Python v3.13))¶
用於控制 TensorDict 中列表處理行為的上下文管理器和裝飾器。
啟用後,分配給 TensorDict 的列表將自動沿著批次維度堆疊。這對於確保列表中的張量或其他元素在 TensorDict 中被視為可堆疊實體非常有用。
- 當前行為
- 如果未透過此上下文管理器將列表分配給 TensorDict,它將被轉換為 numpy 陣列
幷包裝在 NonTensorData 中,如果它無法轉換為張量。
- 引數:
mode ((Python v3.13)bool) – 如果為 True,則啟用列表到堆疊的轉換。如果為 False,則停用它。
示例
>>> with set_list_to_stack(True): ... td = TensorDict(a=[torch.zeros(()), torch.ones(())], batch_size=2) ... assert (td["a"] == torch.tensor([0, 1])).all() ... assert td[0]["a"] == 0 ... assert td[1]["a"] == 1
另請參閱