評價此頁

torch.nanmean#

torch.nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) Tensor#

計算指定維度上所有非 NaN 元素的平均值。輸入必須是浮點型或複數型。

input 張量中沒有 NaN 值時,此函式與 torch.mean() 相同。在存在 NaN 的情況下,torch.mean() 會將 NaN 傳播到輸出,而 torch.nanmean() 則會忽略 NaN 值(torch.nanmean(a) 等同於 torch.mean(a[~a.isnan()]))。

如果 keepdimTrue,則輸出張量的大小與 input 相同,只有在 dim 維度上大小為 1。否則,dim 將被擠壓(參見 torch.squeeze()),導致輸出張量維度減少 1(或 len(dim))個。

引數
  • input (Tensor) – 輸入張量,必須是浮點型或複數型

  • dim (inttuple of ints, optional) – 要規約的維度或維度。如果為 None,則規約所有維度。

  • keepdim (bool, optional) – 輸出張量是否保留 dim。預設為 False

關鍵字引數
  • dtype (torch.dtype, 可選) – 返回張量的期望資料型別。如果指定,則在執行操作之前將輸入張量轉換為 dtype。這對於防止資料型別溢位很有用。預設為 None。

  • out (Tensor, optional) – 輸出張量。

另請參閱

torch.mean() 計算平均值,會傳播 NaN

示例

>>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]])
>>> x.mean()
tensor(nan)
>>> x.nanmean()
tensor(1.8000)
>>> x.mean(dim=0)
tensor([   nan, 1.5000, 2.5000])
>>> x.nanmean(dim=0)
tensor([1.0000, 1.5000, 2.5000])

# If all elements in the reduced dimensions are NaN then the result is NaN
>>> torch.tensor([torch.nan]).nanmean()
tensor(nan)