torch.nanmean#
- torch.nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None) Tensor#
計算指定維度上所有非 NaN 元素的平均值。輸入必須是浮點型或複數型。
當
input張量中沒有 NaN 值時,此函式與torch.mean()相同。在存在 NaN 的情況下,torch.mean()會將 NaN 傳播到輸出,而torch.nanmean()則會忽略 NaN 值(torch.nanmean(a) 等同於 torch.mean(a[~a.isnan()]))。如果
keepdim為True,則輸出張量的大小與input相同,只有在dim維度上大小為 1。否則,dim將被擠壓(參見torch.squeeze()),導致輸出張量維度減少 1(或len(dim))個。- 引數
- 關鍵字引數
dtype (
torch.dtype, 可選) – 返回張量的期望資料型別。如果指定,則在執行操作之前將輸入張量轉換為dtype。這對於防止資料型別溢位很有用。預設為 None。out (Tensor, optional) – 輸出張量。
另請參閱
torch.mean()計算平均值,會傳播 NaN。示例
>>> x = torch.tensor([[torch.nan, 1, 2], [1, 2, 3]]) >>> x.mean() tensor(nan) >>> x.nanmean() tensor(1.8000) >>> x.mean(dim=0) tensor([ nan, 1.5000, 2.5000]) >>> x.nanmean(dim=0) tensor([1.0000, 1.5000, 2.5000]) # If all elements in the reduced dimensions are NaN then the result is NaN >>> torch.tensor([torch.nan]).nanmean() tensor(nan)