快捷方式

isin

class tensordict.utils.isin(input: TensorDictBase, reference: TensorDictBase, key: NestedKey, dim: int = 0)

測試 inputkey 的每個元素是否也存在於 reference 中(沿 dim 維度)。

此函式返回一個長度為 input.batch_size[dim] 的布林張量,對於 key 條目中也存在於 reference 中的元素,其值為 True。該函式假定 inputreference 具有相同的批次大小幷包含指定的條目,否則將引發錯誤。

引數:
  • input (TensorDictBase) – 輸入 TensorDict。

  • reference (TensorDictBase) – 用於測試的目標 TensorDict。

  • key (Nestedkey) – 要測試的鍵。

  • dim (int, optional) – 要測試的維度。預設為 0

返回:

一個長度為 input.batch_size[dim] 的布林張量,對於

存在於 input key 張量中且也存在於 reference 中的元素,其值為 True

返回型別:

out (Tensor)

示例

>>> td = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
...     },
...     batch_size=[4],
... )
>>> td_ref = TensorDict(
...     {
...         "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
...         "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
...     },
...     batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源