Unflatten#
- class torch.nn.Unflatten(dim, unflattened_size)[source]#
將張量的某個維度解扁平化,將其擴充套件到所需的形狀。與
Sequential一起使用。dim指定要解扁平化的輸入張量的維度,當分別使用 Tensor 或 NamedTensor 時,它可以是 int 或 str。unflattened_size是張量解扁平化維度的新的形狀,對於 Tensor 輸入,它可以是 tuple of ints 或 list of ints 或 torch.Size;對於 NamedTensor 輸入,它可以是 NamedShape(由 (name, size) 元組組成的元組)。
- 形狀
輸入: ,其中 是維度
dim的大小, 表示任意數量的維度(包括零個)。輸出: ,其中 =
unflattened_size且 .
- 引數
unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 反展平維度的新的形狀
示例
>>> input = torch.randn(2, 50) >>> # With tuple of ints >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # With namedshape (tuple of tuples) >>> input = torch.randn(2, 50, names=("N", "features")) >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])