快捷方式

tensordict.nn.add_custom_mapping

tensordict.nn.add_custom_mapping(name: str, mapping: Callable[[Tensor], Tensor])

在對映類中新增自定義對映。

引數:
  • name (str) – 對映的名稱。

  • mapping (callable) – 一個可呼叫物件,它接收一個張量作為輸入,並輸出一個具有相同形狀的張量。

示例

>>> from tensordict.nn import add_custom_mapping, NormalParamExtractor
>>> add_custom_mapping("my_mapping", lambda x: torch.zeros_like(x))
>>> npe = NormalParamExtractor(scale_mapping="my_mapping", scale_lb=0.0)
>>> assert (npe(torch.randn(10))[1] == torch.zeros(5)).all()

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源