torch.gather#
- torch.gather(input, dim, index, *, sparse_grad=False, out=None) Tensor#
沿由 dim 指定的軸收集值。
對於一個 3-D 張量,輸出由
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
input和index必須具有相同的維度數。同時,對於所有維度d != dim,要求index.size(d) <= input.size(d)。out的形狀將與index相同。請注意,input和index不會進行廣播。當index為空時,我們始終返回一個具有相同形狀的空輸出,不再進行進一步的錯誤檢查。- 引數
- 關鍵字引數
示例
>>> t = torch.tensor([[1, 2], [3, 4]]) >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]])