torch.flatten#
- torch.flatten(input, start_dim=0, end_dim=-1) Tensor#
透過重塑
input為一維張量來展平它。如果傳入了start_dim或end_dim,則僅展平從start_dim開始、到end_dim結束的維度。input中的元素順序保持不變。與總是複製輸入資料的 NumPy 的 flatten 不同,此函式可能會返回原始物件、檢視或副本。如果沒有任何維度被展平,則返回原始物件
input。否則,如果 input 可以被視為展平後的形狀,則返回該檢視。最後,只有當 input 無法被視為展平後的形狀時,才會複製 input 的資料。有關何時返回檢視的詳細資訊,請參閱torch.Tensor.view()。注意
展平零維張量將返回一個一維檢視。
示例
>>> t = torch.tensor([[[1, 2], ... [3, 4]], ... [[5, 6], ... [7, 8]]]) >>> torch.flatten(t) tensor([1, 2, 3, 4, 5, 6, 7, 8]) >>> torch.flatten(t, start_dim=1) tensor([[1, 2, 3, 4], [5, 6, 7, 8]])