快捷方式

SafeProbabilisticModule

class torchrl.modules.tensordict_module.SafeProbabilisticModule(*args, **kwargs)[原始碼]

tensordict.nn.ProbabilisticTensorDictModule 的子類,它接受一個 TensorSpec 作為引數來控制輸出域。

SafeProbabilisticModule 是一個非引數模組,封裝了一個機率分佈構造器。它從輸入的 TensorDict 中讀取分佈引數,使用指定的 in_keys,並輸出該分佈的一個(寬泛意義上的)樣本。

輸出“樣本”是根據一個規則生成的,該規則由輸入引數 default_interaction_type 和全域性函式 interaction_type() 指定。

SafeProbabilisticModule 可用於構造分佈(透過 get_dist() 方法)和/或從中取樣(透過對模組的常規 __call__() 呼叫)。

一個 SafeProbabilisticModule 例項具有兩個主要特性:

  • 它從 TensorDict 物件讀取並寫入資料;

  • 它使用一個真實的對映 R^n -> R^m 來建立一個 R^d 中的分佈,從中可以取樣或計算值。

當呼叫 __call__()forward() 方法時,會建立一個分佈,並計算一個值(根據 interaction_type 的值,可以使用 'dist.mean'、'dist.mode'、'dist.median' 屬性,以及 'dist.rsample'、'dist.sample' 方法)。如果提供的 TensorDict 已包含所有期望的鍵值對,則會跳過取樣步驟。

預設情況下,SafeProbabilisticModule 的分佈類是 Delta 分佈,這使得 SafeProbabilisticModule 成為一個簡單的確定性對映函式包裝器。

此類與 tensordict.nn.ProbabilisticTensorDictModule 的區別在於,它接受一個 spec 關鍵字引數,該引數可用於控制樣本是否屬於分佈。 safe 關鍵字引數控制樣本值是否應根據 spec 進行檢查。

引數:
  • in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 將從輸入的 TensorDict 讀取並用於構建分佈的鍵。重要的是,如果它是 NestedKey 列表或 NestedKey,這些鍵的葉子(最後一個元素)必須與所關注的分佈類使用的關鍵字匹配,例如,對於 Normal 分佈,關鍵字為 "loc""scale",依此類推。如果 in_keys 是一個字典,鍵是分佈的鍵,值是 tensordict 中將匹配相應分佈鍵的鍵。

  • out_keys (NestedKey | List[NestedKey] | None) – 將寫入取樣值的鍵。重要的是,如果這些鍵在輸入的 TensorDict 中找到,則會跳過取樣步驟。

  • spec (TensorSpec) – 第一個輸出張量的 spec。在呼叫 td_module.random() 生成目標空間中的隨機值時使用。

關鍵字引數:
  • safe (bool, optional) – 如果為 True,則樣本的值將根據輸入 spec 進行檢查。由於探索策略或數值下溢/上溢問題,可能會出現域外取樣。與 spec 引數一樣,此檢查僅針對分佈樣本進行,而不是輸入模組返回的其他張量。如果樣本超出範圍,則使用 TensorSpec.project 方法將其投影回所需空間。預設為 False

  • default_interaction_type (InteractionType, optional) –

    僅關鍵字引數。用於檢索輸出值的預設方法。應為 InteractionType 中的一個:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值將從分佈中隨機取樣)。預設為 MODE。

    注意

    當繪製樣本時,ProbabilisticTensorDictModule 例項將首先查詢由全域性函式 interaction_type() 指定的互動模式。如果此函式返回 None(其預設值),則將使用 ProbabilisticTDModule 例項的 default_interaction_type。請注意,DataCollectorBase 例項將預設使用 set_interaction_type 設定為 tensordict.nn.InteractionType.RANDOM

    注意

    在某些情況下,可能無法透過相應的屬性直接獲取 mode、median 或 mean 值。為解決此問題,ProbabilisticTensorDictModule 將首先嚐試透過呼叫 get_mode()get_median()get_mean() 來獲取值(如果方法存在)。

  • distribution_class (Type or Callable[[Any], Distribution], optional) –

    僅關鍵字引數。用於取樣的 torch.distributions.Distribution 類。預設為 Delta

    注意

    如果 distribution_class 是 CompositeDistribution 型別,則 out_keys 可以直接從該類的 distribution_kwargs 關鍵字引數中透過 "distribution_map""name_map" 推斷出來,在這種情況下 out_keys 是可選的。

  • distribution_kwargs (dict, optional) –

    僅關鍵字引數。要傳遞給分佈的關鍵字引數對。

    注意

    如果您的 kwargs 包含您希望與模組一起傳輸到裝置的張量,或者在呼叫 module.to(dtype) 時其 dtype 應被修改的張量,您可以將 kwargs 包裝在 TensorDictParams 中以自動完成此操作。

  • return_log_prob (bool, optional) – 僅關鍵字引數。如果為 True,則分佈樣本的對數機率將以 log_prob_key 的鍵寫入 tensordict。預設為 False

  • log_prob_keys (List[NestedKey], optional) –

    如果 return_log_prob=True,則寫入 log_prob 的鍵。預設為 ‘<sample_key_name>_log_prob’,其中 <sample_key_name>out_keys 的每個元素。

    注意

    這僅在 composite_lp_aggregate() 設定為 False 時可用。

  • log_prob_key (NestedKey, optional) –

    如果 return_log_prob=True,則寫入 log_prob 的鍵。預設為 ‘sample_log_prob’(當 composite_lp_aggregate() 設定為 True 時)或 ‘<sample_key_name>_log_prob’(否則)。

    注意

    當有多個樣本時,這僅在 composite_lp_aggregate() 設定為 True 時可用。

  • cache_dist (bool, optional) – 僅關鍵字引數。實驗性質:如果為 True,則分佈的引數(即模組的輸出)將與樣本一起寫入 tensordict。這些引數可用於以後重新計算原始分佈(例如,在 PPO 中計算用於取樣動作的分佈與更新後的分佈之間的散度)。預設為 False

  • n_empirical_estimate (int, optional) – 僅關鍵字引數。在均值不可用時計算經驗均值的樣本數量。預設為 1000。

警告

執行檢查需要時間!使用 safe=True 將保證樣本在 spec 邊界內,這依賴於 project() 中編碼的一些啟發式方法,但這需要檢查值是否在 spec 空間內,這將產生一些開銷。

另請參閱

tensordict.nn.CompositeDistribution(複合分佈)可用於建立多頭策略。

示例

>>> from torchrl.modules import SafeProbabilisticModule
>>> from torchrl.data import Bounded
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import InteractionType
>>> mod = SafeProbabilisticModule(
...     in_keys=["loc", "scale"],
...     out_keys=["action"],
...     distribution_class=torch.distributions.Normal,
...     safe=True,
...     spec=Bounded(low=-1, high=1, shape=()),
...     default_interaction_type=InteractionType.RANDOM
... )
>>> _ = torch.manual_seed(0)
>>> data = TensorDict(
...     loc=torch.zeros(10, requires_grad=True),
...     scale=torch.full((10,), 10.0),
...     batch_size=(10,))
>>> data = mod(data)
>>> print(data["action"]) # All actions are within bound
tensor([ 1., -1., -1.,  1., -1., -1.,  1.,  1., -1., -1.],
       grad_fn=<ClampBackward0>)
>>> data["action"].mean().backward()
>>> print(data["loc"].grad) # clamp anihilates gradients
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
random(tensordict: TensorDictBase) TensorDictBase[原始碼]

獨立於任何輸入,在目標空間中隨機取樣一個元素。

如果存在多個輸出鍵,則只有第一個鍵會寫入輸入的 tensordict 中。

引數:

tensordict (TensorDictBase) – 應將輸出值寫入的 tensordict。

返回:

包含輸出鍵的新/更新值的原始 tensordict。

random_sample(tensordict: TensorDictBase) TensorDictBase[原始碼]

請參閱 SafeModule.random(...)

文件

訪問全面的 PyTorch 開發者文件

檢視文件

教程

為初學者和高階開發者提供深入的教程

檢視教程

資源

查詢開發資源並讓您的問題得到解答

檢視資源