torch.nn.functional.cosine_similarity#
- torch.nn.functional.cosine_similarity(x1, x2, dim=1, eps=1e-8) Tensor#
返回
x1和x2之間的餘弦相似度,沿 dim 計算。x1和x2必須可廣播到公共形狀。dim指的是此公共形狀中的維度。輸出的dim維度被擠壓(參見torch.squeeze()),導致輸出張量維度少 1。支援 型別提升。
- 引數
示例
>>> input1 = torch.randn(100, 128) >>> input2 = torch.randn(100, 128) >>> output = F.cosine_similarity(input1, input2) >>> print(output)