torch.cat#
- torch.cat(tensors, dim=0, *, out=None) Tensor#
將給定的張量序列
tensors在指定維度上進行拼接。所有張量必須具有相同的形狀(除了拼接維度),或者是一個大小為(0,)的一維空張量。torch.cat()可以看作是torch.split()和torch.chunk()的逆向操作。torch.cat()可以透過示例來最好地理解。另請參閱
torch.stack()沿著新維度拼接給定的序列。- 引數
tensors (Sequence of Tensors) – 提供的非空張量,除了拼接維度外,必須具有相同的形狀。
dim (int, optional) – 張量被拼接的維度
- 關鍵字引數
out (Tensor, optional) – 輸出張量。
示例
>>> x = torch.randn(2, 3) >>> x tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 0) tensor([[ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497], [ 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497]]) >>> torch.cat((x, x, x), 1) tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614], [-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497]])