快捷方式

distance_loss

class torchrl.objectives.distance_loss(v1: TensorLike, v2: TensorLike, loss_function: str, strict_shape: bool = True)[原始碼]

計算兩個張量之間的距離損失。

引數:
  • v1 (Tensor | TensorDict) – 一個形狀與 v2 相容的張量或 tensordict。

  • v2 (Tensor | TensorDict) – 一個形狀與 v1 相容的張量或 tensordict。

  • loss_function (str) – “l2”、“l1”或“smooth_l1”之一,表示要使用的損失函式。

  • strict_shape (bool) – 如果為 False,則允許 v1 和 v2 具有不同的形狀。預設為 True

返回:

一個形狀為 v1.view_as(v2) 或 v2.view_as(v1) 的張量或 tensordict

其值等於兩者之間的距離損失。

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源