評價此頁

torch.segment_reduce#

torch.segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) Tensor#

在指定軸上對輸入張量執行分段歸約操作。

引數
  • data (Tensor) – 將執行分段歸約操作的輸入張量。

  • reduce (str) – 歸約操作的型別。支援的值包括 summeanmaxminprod

關鍵字引數
  • lengths (Tensor, optional) – 每個段的長度。預設為 None

  • offsets (Tensor, optional) – 每個段的偏移量。預設為 None

  • axis (int, optional) – 用於執行歸約的軸。預設為 0

  • unsafe (bool, optional) – 如果為 True,則跳過驗證。預設為 False

  • initial (Number, optional) – 歸約操作的初始值。預設為 None

示例

>>> data = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]], dtype=torch.float32, device='cuda')
>>> lengths = torch.tensor([2, 1], device='cuda')
>>> torch.segment_reduce(data, 'max', lengths=lengths)
tensor([[ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], device='cuda:0')