torch.combinations#
- torch.combinations(input: Tensor, r: int = 2, with_replacement: bool = False) seq#
計算給定張量的長度為 的組合。當 with_replacement 設定為 False 時,其行為類似於 Python 的 itertools.combinations;當 with_replacement 設定為 True 時,其行為類似於 itertools.combinations_with_replacement。
- 引數
- 返回
一個等效於將所有輸入張量轉換為列表,在這些列表上執行 itertools.combinations 或 itertools.combinations_with_replacement,最後將結果列表轉換為張量的張量。
- 返回型別
示例
>>> a = [1, 2, 3] >>> list(itertools.combinations(a, r=2)) [(1, 2), (1, 3), (2, 3)] >>> list(itertools.combinations(a, r=3)) [(1, 2, 3)] >>> list(itertools.combinations_with_replacement(a, r=2)) [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] >>> tensor_a = torch.tensor(a) >>> torch.combinations(tensor_a) tensor([[1, 2], [1, 3], [2, 3]]) >>> torch.combinations(tensor_a, r=3) tensor([[1, 2, 3]]) >>> torch.combinations(tensor_a, with_replacement=True) tensor([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]])