評價此頁

torch.diag#

torch.diag(input, diagonal=0, *, out=None) Tensor#
  • 如果 input 是一個向量(1-D 張量),則返回一個 2-D 方形張量,其中 input 的元素作為對角線。

  • 如果 input 是一個矩陣(2-D 張量),則返回一個 1-D 張量,其中包含 input 的對角線元素。

diagonal 引數控制要考慮哪個對角線

  • 如果 diagonal = 0,則為主對角線。

  • 如果 diagonal > 0,則為上對角線。

  • 如果 diagonal < 0,則為下對角線。

引數
  • input (Tensor) – 輸入張量。

  • diagonal (int, optional) – 要考慮的對角線

關鍵字引數

out (Tensor, optional) – 輸出張量。

另請參閱

torch.diagonal() 始終返回其輸入的對角線。

torch.diagflat() 始終構造一個對角線元素由輸入指定的張量。

示例

獲取輸入向量作為對角線的方陣

>>> a = torch.randn(3)
>>> a
tensor([ 0.5950,-0.0872, 2.3298])
>>> torch.diag(a)
tensor([[ 0.5950, 0.0000, 0.0000],
        [ 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 2.3298]])
>>> torch.diag(a, 1)
tensor([[ 0.0000, 0.5950, 0.0000, 0.0000],
        [ 0.0000, 0.0000,-0.0872, 0.0000],
        [ 0.0000, 0.0000, 0.0000, 2.3298],
        [ 0.0000, 0.0000, 0.0000, 0.0000]])

獲取給定矩陣的第 k 條對角線

>>> a = torch.randn(3, 3)
>>> a
tensor([[-0.4264, 0.0255,-0.1064],
        [ 0.8795,-0.2429, 0.1374],
        [ 0.1029,-0.6482,-1.6300]])
>>> torch.diag(a, 0)
tensor([-0.4264,-0.2429,-1.6300])
>>> torch.diag(a, 1)
tensor([ 0.0255, 0.1374])