torch.diagflat#
- torch.diagflat(input, offset=0) Tensor#
如果
input是一個向量(1-D 張量),則返回一個 2-D 方形張量,其中input的元素作為對角線。如果
input是一個具有多個維度的張量,則返回一個 2-D 張量,其對角線元素等於一個展平的input。
引數
offset控制要考慮的對角線如果
offset= 0,則為主對角線。如果
offset> 0,則位於主對角線之上。如果
offset< 0,則位於主對角線之下。
示例
>>> a = torch.randn(3) >>> a tensor([-0.2956, -0.9068, 0.1695]) >>> torch.diagflat(a) tensor([[-0.2956, 0.0000, 0.0000], [ 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.1695]]) >>> torch.diagflat(a, 1) tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.1695], [ 0.0000, 0.0000, 0.0000, 0.0000]]) >>> a = torch.randn(2, 2) >>> a tensor([[ 0.2094, -0.3018], [-0.1516, 1.9342]]) >>> torch.diagflat(a) tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], [ 0.0000, -0.3018, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.1516, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.9342]])