ProbabilisticTensorDictModule¶
- class tensordict.nn.ProbabilisticTensorDictModule(*args, **kwargs)¶
一個機率 TD 模組。
ProbabilisticTensorDictModule 是一個非引數模組,用於嵌入一個機率分佈構造器。它使用指定的 in_keys 從輸入的 TensorDict 中讀取分佈引數,並輸出分佈的(廣義上的)樣本。
輸出的“樣本”是根據一個規則生成的,該規則由輸入引數
default_interaction_type和全域性函式interaction_type()指定。ProbabilisticTensorDictModule 可用於構造分佈(透過
get_dist()方法)和/或從該分佈進行取樣(透過對模組進行常規的__call__()呼叫)。一個 ProbabilisticTensorDictModule 例項具有兩個主要特性:
它讀寫 TensorDict 物件;
它使用一個實數對映 R^n -> R^m 來建立 R^d 中的一個分佈,從中可以取樣或計算值。
當呼叫
__call__()和forward()方法時,會建立一個分佈,並計算一個值(根據interaction_type的值,可以使用 ‘dist.mean’、‘dist.mode’、‘dist.median’ 屬性,以及 ‘dist.rsample’、‘dist.sample’ 方法)。如果提供的 TensorDict 已包含所有期望的鍵值對,則會跳過取樣步驟。預設情況下,ProbabilisticTensorDictModule 的分佈類是
Delta分佈,這使得 ProbabilisticTensorDictModule 成為一個簡單的確定性對映函式包裝器。- 引數:
in_keys (NestedKey | List[NestedKey] | Dict[str, NestedKey]) – 將從輸入 TensorDict 讀取並用於構建分佈的鍵。重要的是,如果它是 NestedKey 列表或 NestedKey,則這些鍵的葉子(最後一個元素)必須與感興趣的分佈類使用的關鍵字匹配,例如,對於
Normal分佈,需要匹配"loc"和"scale",依此類推。如果 in_keys 是一個字典,鍵是分佈的鍵,值是 tensordict 中將與相應的分佈鍵匹配的鍵。out_keys (NestedKey | List[NestedKey] | None) – 取樣值將被寫入的鍵。重要的是,如果這些鍵存在於輸入 TensorDict 中,則會跳過取樣步驟。
- 關鍵字引數:
default_interaction_type (InteractionType, optional) –
僅關鍵字引數。用於檢索輸出值的預設方法。應為 InteractionType 中的一個:MODE、MEDIAN、MEAN 或 RANDOM(在這種情況下,值將從分佈中隨機取樣)。預設為 MODE。
注意
當繪製樣本時,
ProbabilisticTensorDictModule例項將首先查詢由全域性函式interaction_type()指定的互動模式。如果此函式返回 None(其預設值),則將使用 default_interaction_type 屬性ProbabilisticTDModule例項。注意
在某些情況下,模式、中位數或均值可能無法透過相應的屬性輕鬆獲得。為了緩解這個問題,
ProbabilisticTensorDictModule將首先嚐試透過呼叫get_mode()、get_median()或get_mean()來獲取值(如果方法存在)。distribution_class (Type or Callable[[Any], Distribution], optional) –
僅關鍵字引數。用於取樣的
torch.distributions.Distribution類。預設為Delta。注意
如果分佈類是
CompositeDistribution型別,則out_keys可以直接從該類的distribution_kwargs關鍵字引數中提供的"distribution_map"或"name_map"鍵中推斷出來,從而使out_keys在這種情況下成為可選。distribution_kwargs (dict, optional) –
僅關鍵字引數。將傳遞給分佈的關鍵字引數對。
注意
如果您的 kwargs 包含您希望與模組一起傳輸到裝置上的張量,或者在呼叫 module.to(dtype) 時應修改其 dtype 的張量,您可以將 kwargs 包裝在
TensorDictParams中以自動執行此操作。return_log_prob (bool, optional) – 僅關鍵字引數。如果為
True,則分佈樣本的對數機率將寫入 tensordict,鍵為 log_prob_key。預設為False。log_prob_keys (List[NestedKey], optional) –
當
return_log_prob=True時,寫入 log_prob 的鍵。預設為 ‘<sample_key_name>_log_prob’,其中 <sample_key_name> 是每個out_keys。注意
當
composite_lp_aggregate()設定為False時,這僅在以下情況下可用。log_prob_key (NestedKey, optional) –
當
return_log_prob=True時,寫入 log_prob 的鍵。當composite_lp_aggregate()設定為True時,預設為 ‘sample_log_prob’,否則為 ‘<sample_key_name>_log_prob’。注意
當有多個樣本時,這僅在
composite_lp_aggregate()設定為True時可用。cache_dist (bool, optional) – 僅關鍵字引數。實驗性:如果為
True,則分佈的引數(即模組的輸出)將與樣本一起寫入 tensordict。這些引數可用於稍後重新計算原始分佈(例如,在 PPO 中計算用於取樣動作的分佈與更新後的分佈之間的散度)。預設為False。n_empirical_estimate (int, optional) – 僅關鍵字引數。在經驗均值不可用時,用於計算經驗均值的樣本數量。預設為 1000。
示例
>>> import torch >>> from tensordict import TensorDict >>> from tensordict.nn import ( ... ProbabilisticTensorDictModule, ... ProbabilisticTensorDictSequential, ... TensorDictModule, ... ) >>> from tensordict.nn.distributions import NormalParamExtractor >>> from tensordict.nn.functional_modules import make_functional >>> from torch.distributions import Normal, Independent >>> td = TensorDict( ... {"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3] ... ) >>> net = torch.nn.GRUCell(4, 8) >>> module = TensorDictModule( ... net, in_keys=["input", "hidden"], out_keys=["params"] ... ) >>> normal_params = TensorDictModule( ... NormalParamExtractor(), in_keys=["params"], out_keys=["loc", "scale"] ... ) >>> def IndepNormal(**kwargs): ... return Independent(Normal(**kwargs), 1) >>> prob_module = ProbabilisticTensorDictModule( ... in_keys=["loc", "scale"], ... out_keys=["action"], ... distribution_class=IndepNormal, ... return_log_prob=True, ... ) >>> td_module = ProbabilisticTensorDictSequential( ... module, normal_params, prob_module ... ) >>> params = TensorDict.from_module(td_module) >>> with params.to_module(td_module): ... _ = td_module(td) >>> print(td) TensorDict( fields={ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) >>> with params.to_module(td_module): ... dist = td_module.get_dist(td) >>> print(dist) Independent(Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4])), 1) >>> # we can also apply the module to the TensorDict with vmap >>> from torch import vmap >>> params = params.expand(4) >>> def func(td, params): ... with params.to_module(td_module): ... return td_module(td) >>> td_vmap = vmap(func, (None, 0))(td, params) >>> print(td_vmap) TensorDict( fields={ action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), params: Tensor(shape=torch.Size([4, 3, 8]), device=cpu, dtype=torch.float32, is_shared=False), sample_log_prob: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 3]), device=None, is_shared=False)
- build_dist_from_params(tensordict: TensorDictBase) Distribution¶
使用輸入 tensordict 中提供的引數建立一個
torch.distribution.Distribution例項。- 引數:
tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。
- 返回:
一個從輸入 tensordict 建立的
torch.distribution.Distribution例項。- 丟擲:
TypeError – 如果輸入 tensordict 與分佈關鍵字不匹配。
- forward(tensordict: TensorDictBase = None, tensordict_out: tensordict.base.TensorDictBase | None = None, _requires_sample: bool = True) TensorDictBase¶
定義每次呼叫時執行的計算。
所有子類都應重寫此方法。
注意
儘管前向傳播的實現需要在此函式中定義,但您應該在之後呼叫
Module例項而不是此函式,因為前者會處理註冊的鉤子,而後者則會靜默忽略它們。
- get_dist(tensordict: TensorDictBase) Distribution¶
使用輸入 tensordict 中提供的引數建立一個
torch.distribution.Distribution例項。- 引數:
tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。
- 返回:
一個從輸入 tensordict 建立的
torch.distribution.Distribution例項。- 丟擲:
TypeError – 如果輸入 tensordict 與分佈關鍵字不匹配。
- log_prob(tensordict, *, dist: Optional[Distribution] = None)¶
計算分佈樣本的對數機率。
- 引數:
tensordict (TensorDictBase) – 包含分佈引數的輸入 tensordict。
dist (torch.distributions.Distribution, optional) – 分佈例項。預設為
None。如果為None,則將使用 get_dist 方法計算分佈。
- 返回:
表示分佈樣本對數機率的張量。