評價此頁

torch.flatten#

torch.flatten(input, start_dim=0, end_dim=-1) Tensor#

透過重塑 input 為一維張量來展平它。如果傳入了 start_dimend_dim,則僅展平從 start_dim 開始、到 end_dim 結束的維度。 input 中的元素順序保持不變。

與總是複製輸入資料的 NumPy 的 flatten 不同,此函式可能會返回原始物件、檢視或副本。如果沒有任何維度被展平,則返回原始物件 input。否則,如果 input 可以被視為展平後的形狀,則返回該檢視。最後,只有當 input 無法被視為展平後的形狀時,才會複製 input 的資料。有關何時返回檢視的詳細資訊,請參閱 torch.Tensor.view()

注意

展平零維張量將返回一個一維檢視。

引數
  • input (Tensor) – 輸入張量。

  • start_dim (int) – 要展平的第一個維度

  • end_dim (int) – 要展平的最後一個維度

示例

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])