terminated_or_truncated¶
- torchrl.envs.terminated_or_truncated(data: TensorDictBase, full_done_spec: TensorSpec | None = None, key: str = '_reset', write_full_false: bool = False) bool[source]¶
讀取 tensordict 中的 done / terminated / truncated 鍵,並寫入一個新張量,其中兩個訊號的值被聚合。
修改將就地發生在提供的 TensorDict 例項中。此函式可用於在批次或多智慧體設定中計算 “_reset” 訊號,因此輸出鍵的預設名稱。
- 引數:
data (TensorDictBase) – 輸入資料,通常來自
step()的呼叫。full_done_spec (TensorSpec, optional) – 來自 env 的 done_spec,指示 done 鍵的位置。如果未提供,則將在資料中搜索預設的
"done"、"terminated"和"truncated"條目。key (NestedKey, optional) –
聚合結果應寫入的位置。如果為
None,則函式將不寫入任何鍵,而僅輸出是否任何 done 值都為 true。.. note:: 如果key條目已存在相應的值,則先前的值將優先,並且不會實現更新。
write_full_false (bool, optional) – 如果為
True,即使輸出為False(即,在提供的資料結構中沒有 done 為True),也將寫入 reset 鍵。預設為False。
- 返回:一個布林值,指示資料中找到的任何 done 狀態
是否包含
True。
示例
>>> from torchrl.data.tensor_specs import Categorical >>> from tensordict import TensorDict >>> spec = Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... nested=Composite( ... done=Categorical(2, dtype=torch.bool), ... truncated=Categorical(2, dtype=torch.bool), ... ) ... ) >>> data = TensorDict({ ... "done": True, "truncated": False, ... "nested": {"done": False, "truncated": True}}, ... batch_size=[] ... ) >>> data = _terminated_or_truncated(data, spec) >>> print(data["_reset"]) tensor(True) >>> print(data["nested", "_reset"]) tensor(True)