評價此頁

torch.median#

torch.median(input) Tensor#

返回 input 中值的中位數。

注意

input 張量包含偶數個元素時,中位數不是唯一的。在這種情況下,將返回兩個中位數中較低的一個。要計算兩個中位數的平均值,請改用 q=0.5torch.quantile()

警告

此函式產生確定性的(子)梯度,與 median(dim=0) 不同。

引數

input (Tensor) – 輸入張量。

示例

>>> a = torch.randn(1, 3)
>>> a
tensor([[ 1.5219, -1.5212,  0.2202]])
>>> torch.median(a)
tensor(0.2202)
torch.median(input, dim=-1, keepdim=False, *, out=None)

返回一個命名元組 (values, indices),其中 values 包含 inputdim 維度上每行的中位數,而 indices 包含在 dim 維度上找到的中位數值的索引。

預設情況下,diminput 張量的最後一個維度。

如果 keepdimTrue,則輸出張量的大小與 input 相同,除了在 dim 維度上,它們的大小為 1。否則,dim 將被擠壓(參見 torch.squeeze()),導致輸出張量比 input 少一個維度。

注意

input 張量在 dim 維度上包含偶數個元素時,中位數不是唯一的。在這種情況下,將返回兩個中位數中較低的一個。要計算 input 中兩個中位數的平均值,請改用 q=0.5torch.quantile()

警告

indices 不一定包含找到的每個中位數值的第一次出現,除非該值是唯一的。具體的實現細節取決於裝置。通常不要期望在 CPU 和 GPU 上執行得到相同的結果。出於同樣的原因,不要期望梯度是確定的。

引數
  • input (Tensor) – 輸入張量。

  • dim (int, optional) – 要約簡的維度。如果為 None,則約簡所有維度。

  • keepdim (bool, optional) – 輸出張量是否保留 dim。預設為 False

關鍵字引數

out ((Tensor, Tensor), optional) – 第一個張量將填充中位數,第二個張量(必須為 long 型別)將填充 inputdim 維度上的索引。

示例

>>> a = torch.randn(4, 5)
>>> a
tensor([[ 0.2505, -0.3982, -0.9948,  0.3518, -1.3131],
        [ 0.3180, -0.6993,  1.0436,  0.0438,  0.2270],
        [-0.2751,  0.7303,  0.2192,  0.3321,  0.2488],
        [ 1.0778, -1.9510,  0.7048,  0.4742, -0.7125]])
>>> torch.median(a, 1)
torch.return_types.median(values=tensor([-0.3982,  0.2270,  0.2488,  0.4742]), indices=tensor([1, 4, 4, 3]))