機率分佈 - torch.distributions#
創建於:2017年10月19日 | 最後更新於:2025年6月13日
distributions 包包含可引數化的機率分佈和取樣函式。這允許構建用於最佳化的隨機計算圖和隨機梯度估計器。該包通常遵循 TensorFlow Distributions 包的設計。
無法直接反向傳播隨機樣本。但是,有兩種主要方法可以建立可反向傳播的代理函式:得分函式估計器/似然比估計器/REINFORCE 和路徑wise導數估計器。REINFORCE 通常被視為強化學習中策略梯度方法的基礎,而路徑wise導數估計器通常出現在變分自編碼器中的重引數化技巧中。雖然得分函式僅需要樣本 的值,路徑wise導數需要導數 。接下來的幾節將在強化學習示例中討論這兩種方法。更多詳情請參閱 使用隨機計算圖進行梯度估計。
得分函式#
當機率密度函式相對於其引數可微時,我們只需要 sample() 和 log_prob() 來實現 REINFORCE。
其中 是引數, 是學習率, 是獎勵, 是在策略 下,狀態 下采取動作 的機率。
實際上,我們將從網路的輸出中取樣一個動作,在環境中應用該動作,然後使用 log_prob 來構建等效的損失函式。請注意,我們使用的是負號,因為最佳化器使用梯度下降,而上述規則假定梯度上升。對於分類策略,實現 REINFORCE 的程式碼如下:
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
路徑wise導數#
實現這些隨機/策略梯度的另一種方法是使用 rsample() 方法中的重引數化技巧,其中引數化隨機變數可以透過無引數隨機變數的引數化確定性函式來構造。因此,重引數化樣本變得可微分。實現路徑wise導數的程式碼如下:
params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action) # Assuming that reward is differentiable
loss = -reward
loss.backward()
Distribution#
- class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]#
Bases:
objectDistribution 是機率分佈的抽象基類。
- 引數
batch_shape (torch.Size) – 引數被分批處理的形狀。
event_shape (torch.Size) – 單個樣本(不帶批處理)的形狀。
validate_args (bool, optional) – 是否驗證引數。預設值:None。
- property arg_constraints: dict[str, torch.distributions.constraints.Constraint]#
返回一個字典,將引數名稱對映到
Constraint物件,這些物件應由此分佈的每個引數滿足。不在此字典中的引數(非張量)無需考慮。
- enumerate_support(expand=True)[source]#
返回一個包含離散分佈所有支援值的張量。結果將列舉維度 0,因此結果的形狀為 (基數,) + batch_shape + event_shape(其中單變數分佈的 event_shape = ())。
請注意,這會以同步方式列舉所有批處理張量,例如 [[0, 0], [1, 1], …]。當 expand=False 時,列舉沿 dim 0 進行,但其餘批處理維度為單例維度,即 [[0], [1], .. 。
要迭代完整的笛卡爾積,請使用 itertools.product(m.enumerate_support())。
- expand(batch_shape, _instance=None)[source]#
返回一個新的分佈例項(或填充由派生類提供的現有例項),其批處理維度已擴充套件到 batch_shape。此方法呼叫分佈引數上的
expand。因此,這不會為擴充套件的分佈例項分配新記憶體。此外,這不會在例項首次建立時重複 __init__.py 中的任何引數檢查或廣播。- 引數
batch_shape (torch.Size) – 所需的擴充套件大小。
_instance – 由需要重寫 .expand 的子類提供的新的例項。
- 返回
具有批處理維度擴充套件到 batch_size 的新分佈例項。
- rsample(sample_shape=torch.Size([]))[source]#
生成 sample_shape 形狀的重引數化樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的重引數化樣本批次。
- 返回型別
- sample(sample_shape=torch.Size([]))[source]#
生成 sample_shape 形狀的樣本,如果分佈引數是批處理的,則生成 sample_shape 形狀的樣本批次。
- 返回型別
- static set_default_validate_args(value)[source]#
設定是否啟用或停用驗證。
預設行為模仿 Python 的
assert語句:預設情況下啟用驗證,但如果 Python 在最佳化模式下執行(透過python -O),則停用驗證。驗證可能很昂貴,因此一旦模型工作正常,您可能希望停用它。- 引數
value (bool) – 是否啟用驗證。
- property support: Optional[Constraint]#
返回一個
Constraint物件,表示此分佈的支援域。
ExponentialFamily#
- class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]#
Bases:
DistributionExponentialFamily 是屬於指數族的機率分佈的抽象基類,其機率質量/密度函式形式如下:
其中 表示自然引數, 表示充分統計量, 是給定族對的對數歸一化函式, 是載體測度。
注意
此類是 Distribution 類和屬於指數族的分佈之間的中間項,主要用於檢查 .entropy() 和解析 KL 散度方法的正確性。我們使用此類透過對數歸一化函式的 Bregman 散度來計算熵和 KL 散度(來源:Frank Nielsen 和 Richard Nock,Entropies and Cross-entropies of Exponential Families)。
Bernoulli#
- class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)[source]#
Bases:
ExponentialFamily使用
probs或logits(但不能同時)引數化的 Bernoulli 分佈。樣本是二元的(0 或 1)。它們以機率 p 取值 1,以機率 1 - p 取值 0。
示例
>>> m = Bernoulli(torch.tensor([0.3])) >>> m.sample() # 30% chance 1; 70% chance 0 tensor([ 0.])
- 引數
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}#
- has_enumerate_support = True#
- support = Boolean()#
Beta#
- class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)[source]#
Bases:
ExponentialFamily由
concentration1和concentration0引數化的 Beta 分佈。示例
>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) >>> m.sample() # Beta distributed with concentration concentration1 and concentration0 tensor([ 0.1046])
- 引數
- arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Interval(lower_bound=0.0, upper_bound=1.0)#
Binomial#
- class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]#
Bases:
Distribution建立由
total_count和probs或logits(但不能同時)引數化的二項分佈。total_count必須與probs/logits相容。示例
>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1])) >>> x = m.sample() tensor([ 0., 22., 71., 100.]) >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8])) >>> x = m.sample() tensor([[ 4., 5.], [ 7., 6.]])
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}#
- has_enumerate_support = True#
- property support#
- 返回型別
_DependentProperty
Categorical#
- class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)[source]#
Bases:
Distribution建立由
probs或logits(但不能同時)引數化的分類分佈。注意
它等價於
torch.multinomial()取樣的分佈。樣本是整數,取值範圍為 ,其中 K 是
probs.size(-1)。如果 probs 是長度為 K 的一維張量,則每個元素是取樣該索引處類別的相對機率。
如果 probs 是 N 維張量,則前 N-1 維被視為相對機率向量的批次。
注意
probs 引數必須非負、有限且總和非零,它將在最後一個維度上被歸一化為總和為 1。
probs將返回此歸一化值。 logits 引數將被解釋為未歸一化的對數機率,因此可以是任何實數。它也將被歸一化,以便生成的機率在最後一個維度上總和為 1。logits將返回此歸一化值。另請參閱:
torch.multinomial()示例
>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # equal probability of 0, 1, 2, 3 tensor(3)
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}#
- has_enumerate_support = True#
- property support#
- 返回型別
_DependentProperty
Cauchy#
- class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)[source]#
Bases:
Distribution從柯西(洛倫茲)分佈取樣。均值為 0 的獨立正態分佈隨機變數之比的分佈遵循柯西分佈。
示例
>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1 tensor([ 2.3214])
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Real()#
Chi2#
- class torch.distributions.chi2.Chi2(df, validate_args=None)[source]#
Bases:
Gamma建立由形狀引數
df引數化的卡方分佈。這與Gamma(alpha=0.5*df, beta=0.5)完全等價。示例
>>> m = Chi2(torch.tensor([1.0])) >>> m.sample() # Chi2 distributed with shape df=1 tensor([ 0.1046])
- arg_constraints = {'df': GreaterThan(lower_bound=0.0)}#
ContinuousBernoulli#
- class torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=None, logits=None, lims=(0.499, 0.501), validate_args=None)[source]#
Bases:
ExponentialFamily使用
probs或logits(但不能同時)引數化的連續 Bernoulli 分佈。該分佈的支援域為 [0, 1],由 'probs'(在 (0,1) 內)或 'logits'(實數值)引數化。請注意,與 Bernoulli 不同,'probs' 不代表機率,'logits' 也不代表對數機率,但由於與 Bernoulli 的相似性,使用了相同的名稱。更多詳情請參閱 [1]。
示例
>>> m = ContinuousBernoulli(torch.tensor([0.3])) >>> m.sample() tensor([ 0.2538])
[1] The continuous Bernoulli: fixing a pervasive error in variational autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. https://arxiv.org/abs/1907.06845
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}#
- has_rsample = True#
- support = Interval(lower_bound=0.0, upper_bound=1.0)#
Dirichlet#
- class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)[原始碼]#
Bases:
ExponentialFamily建立一個由
concentration引數化的 Dirichlet 分佈。示例
>>> m = Dirichlet(torch.tensor([0.5, 0.5])) >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5] tensor([ 0.1046, 0.8954])
- 引數
concentration (Tensor) – 分佈的 concentration 引數(通常稱為 alpha)
- arg_constraints = {'concentration': IndependentConstraint(GreaterThan(lower_bound=0.0), 1)}#
- has_rsample = True#
- support = Simplex()#
Exponential#
- class torch.distributions.exponential.Exponential(rate, validate_args=None)[原始碼]#
Bases:
ExponentialFamily建立一個由
rate引數化的 Exponential 分佈。示例
>>> m = Exponential(torch.tensor([1.0])) >>> m.sample() # Exponential distributed with rate=1 tensor([ 0.1046])
- arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThanEq(lower_bound=0.0)#
FisherSnedecor#
- class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由
df1和df2引數化的 Fisher-Snedecor 分佈。示例
>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0])) >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2 tensor([ 0.2453])
- 引數
- arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThan(lower_bound=0.0)#
Gamma#
- class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)[原始碼]#
Bases:
ExponentialFamily建立一個由
concentration和rate引數化的 Gamma 分佈。示例
>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # Gamma distributed with concentration=1 and rate=1 tensor([ 0.1046])
- 引數
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThanEq(lower_bound=0.0)#
GeneralizedPareto#
- class torch.distributions.generalized_pareto.GeneralizedPareto(loc, scale, concentration, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由
loc、scale和concentration引數化的 Generalized Pareto 分佈。Generalized Pareto 分佈是一類在實線上的連續機率分佈。特殊情況包括指數分佈(當
loc= 0,concentration= 0),Pareto 分佈(當concentration> 0,loc=scale/concentration),以及均勻分佈(當concentration= -1)。該分佈常用於對其他分佈的尾部進行建模。此實現基於 TensorFlow Probability 中的實現。
示例
>>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4])) >>> m.sample() # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4 tensor([ 1.5623])
- 引數
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'concentration': Real(), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- property mean#
- property mode#
- property support#
- 返回型別
_DependentProperty
- property variance#
Geometric#
- class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由
probs引數化的 Geometric 分佈,其中probs是 Bernoulli 試驗成功的機率。注意
torch.distributions.geometric.Geometric()-th trial is the first success hence draws samples in , whereastorch.Tensor.geometric_()k-th trial is the first success hence draws samples in .示例
>>> m = Geometric(torch.tensor([0.3])) >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 tensor([ 2.])
- 引數
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}#
- support = IntegerGreaterThan(lower_bound=0)#
Gumbel#
- class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)[原始碼]#
Bases:
TransformedDistribution從 Gumbel 分佈取樣。
示例
>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2 tensor([ 1.0124])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- support = Real()#
HalfCauchy#
- class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)[原始碼]#
Bases:
TransformedDistribution建立一個由 scale 引數化的 half-Cauchy 分佈,其中
X ~ Cauchy(0, scale) Y = |X| ~ HalfCauchy(scale)
示例
>>> m = HalfCauchy(torch.tensor([1.0])) >>> m.sample() # half-cauchy distributed with scale=1 tensor([ 2.3214])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThanEq(lower_bound=0.0)#
HalfNormal#
- class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)[原始碼]#
Bases:
TransformedDistribution建立一個由 scale 引數化的 half-normal 分佈,其中
X ~ Normal(0, scale) Y = |X| ~ HalfNormal(scale)
示例
>>> m = HalfNormal(torch.tensor([1.0])) >>> m.sample() # half-normal distributed with scale=1 tensor([ 0.1046])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThanEq(lower_bound=0.0)#
Independent#
- class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)[原始碼]#
Bases:
Distribution,Generic[D]將分佈的某些 batch 維度重新解釋為 event 維度。
這主要用於更改
log_prob()結果的形狀。例如,要建立形狀與 Multivariate Normal 分佈相同的對角 Normal 分佈(以便它們可以互換),您可以>>> from torch.distributions.multivariate_normal import MultivariateNormal >>> from torch.distributions.normal import Normal >>> loc = torch.zeros(3) >>> scale = torch.ones(3) >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale)) >>> [mvn.batch_shape, mvn.event_shape] [torch.Size([]), torch.Size([3])] >>> normal = Normal(loc, scale) >>> [normal.batch_shape, normal.event_shape] [torch.Size([3]), torch.Size([])] >>> diagn = Independent(normal, 1) >>> [diagn.batch_shape, diagn.event_shape] [torch.Size([]), torch.Size([3])]
- 引數
base_distribution (torch.distributions.distribution.Distribution) – a base distribution
reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {}#
- base_dist: D#
- property support#
- 返回型別
_DependentProperty
逆Gamma分佈#
- class torch.distributions.inverse_gamma.InverseGamma(concentration, rate, validate_args=None)[原始碼]#
Bases:
TransformedDistribution建立一個由
concentration和rate引數化的逆 Gamma 分佈,其中X ~ Gamma(concentration, rate) Y = 1 / X ~ InverseGamma(concentration, rate)
示例
>>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0])) >>> m.sample() tensor([ 1.2953])
- 引數
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThan(lower_bound=0.0)#
庫馬拉斯瓦米分佈#
- class torch.distributions.kumaraswamy.Kumaraswamy(concentration1, concentration0, validate_args=None)[原始碼]#
Bases:
TransformedDistribution從庫馬拉斯瓦米分佈中取樣。
示例
>>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1 tensor([ 0.1729])
- 引數
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Interval(lower_bound=0.0, upper_bound=1.0)#
LKJCholesky#
- class torch.distributions.lkj_cholesky.LKJCholesky(dim, concentration=1.0, validate_args=None)[原始碼]#
Bases:
Distribution用於相關矩陣的下三角 Cholesky 分解的 LKJ 分佈。該分佈由
concentration引數 控制,使得從 Cholesky 分解生成的協方差矩陣 的機率正比於 。因此,當concentration == 1時,我們得到一個均勻分佈在 Cholesky 分解的相關矩陣上的分佈。L ~ LKJCholesky(dim, concentration) X = L @ L' ~ LKJCorr(dim, concentration)
請注意,此分佈取樣的是相關矩陣的 Cholesky 分解,而不是相關矩陣本身,因此與 [1] 中 LKJCorr 分佈的推導略有不同。取樣時,使用 [1] Section 3 中的 Onion 方法。
示例
>>> l = LKJCholesky(3, 0.5) >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix tensor([[ 1.0000, 0.0000, 0.0000], [ 0.3516, 0.9361, 0.0000], [-0.1899, 0.4748, 0.8593]])
參考文獻
[1] Generating random correlation matrices based on vines and extended onion method (2009), Daniel Lewandowski, Dorota Kurowicka, Harry Joe. Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}#
- support = CorrCholesky()#
拉普拉斯分佈#
- class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由
loc和scale引數化的拉普拉斯分佈。示例
>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # Laplace distributed with loc=0, scale=1 tensor([ 0.1046])
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Real()#
對數正態分佈#
- class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)[原始碼]#
Bases:
TransformedDistribution建立一個由
loc和scale引數化的對數正態分佈,其中X ~ Normal(loc, scale) Y = exp(X) ~ LogNormal(loc, scale)
示例
>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # log-normal distributed with mean=0 and stddev=1 tensor([ 0.1046])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = GreaterThan(lower_bound=0.0)#
低秩多元正態分佈#
- class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[原始碼]#
Bases:
Distribution建立一個具有低秩協方差矩陣的多元正態分佈,該分佈由
cov_factor和cov_diag引數化。covariance_matrix = cov_factor @ cov_factor.T + cov_diag
示例
>>> m = LowRankMultivariateNormal( ... torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2) ... ) >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]` tensor([-0.2102, -0.5429])
- 引數
注意
當 cov_factor.shape[1] << cov_factor.shape[0] 時,由於 Woodbury 矩陣恆等式 和 矩陣行列式引理,避免了協方差矩陣的行列式和逆矩陣的計算。藉助這些公式,我們只需要計算小尺寸“電容”矩陣的行列式和逆矩陣。
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
- arg_constraints = {'cov_diag': IndependentConstraint(GreaterThan(lower_bound=0.0), 1), 'cov_factor': IndependentConstraint(Real(), 2), 'loc': IndependentConstraint(Real(), 1)}#
- has_rsample = True#
- support = IndependentConstraint(Real(), 1)#
相同族混合分佈#
- class torch.distributions.mixture_same_family.MixtureSameFamily(mixture_distribution, component_distribution, validate_args=None)[原始碼]#
Bases:
DistributionMixtureSameFamily 分佈實現了一個 (批次的) 混合分佈,其中所有元件都來自同一分佈型別的不同引數化。它由一個“選擇分佈” Categorical (用於選擇 k 個元件) 和一個元件分佈引數化,即一個右側批次形狀 (等於 [k]) 的 Distribution,用於索引每個 (批次的) 元件。
示例
>>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally >>> # weighted normal distributions >>> mix = D.Categorical(torch.ones(5,)) >>> comp = D.Normal(torch.randn(5,), torch.rand(5,)) >>> gmm = MixtureSameFamily(mix, comp) >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally >>> # weighted bivariate normal distributions >>> mix = D.Categorical(torch.ones(5,)) >>> comp = D.Independent(D.Normal( ... torch.randn(5,2), torch.rand(5,2)), 1) >>> gmm = MixtureSameFamily(mix, comp) >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each >>> # consisting of 5 random weighted bivariate normal distributions >>> mix = D.Categorical(torch.rand(3,5)) >>> comp = D.Independent(D.Normal( ... torch.randn(3,5,2), torch.rand(3,5,2)), 1) >>> gmm = MixtureSameFamily(mix, comp)
- 引數
mixture_distribution (Categorical) – torch.distributions.Categorical 類例項。管理選擇元件的機率。類別數量必須與 component_distribution 的最右側批次維度匹配。必須具有標量 batch_shape 或 batch_shape 與 component_distribution.batch_shape[:-1] 匹配。
component_distribution (Distribution) – torch.distributions.Distribution 類例項。最右側批次維度索引元件。
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {}#
- property component_distribution: Distribution#
- has_rsample = False#
- property mixture_distribution: Categorical#
- property support#
- 返回型別
_DependentProperty
多項分佈#
- class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由
total_count和probs或logits(但不能同時) 引數化的多項分佈。probs的最內層維度索引類別。所有其他維度索引批次。注意,如果只調用
log_prob(),則total_count則無需指定 (請參見下面的示例)。注意
probs 引數必須是非負、有限且總和非零的,它將在最後一個維度上進行歸一化以使總和為 1。
probs將返回此歸一化值。 logits 引數將被解釋為未歸一化的對數機率,因此可以是任何實數。它同樣會被歸一化,以使由此產生的機率在最後一個維度上總和為 1。logits將返回此歸一化值。sample()需要為所有引數和樣本共享一個 total_count。log_prob()允許為每個引數和樣本使用不同的 total_count。
示例
>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.])) >>> x = m.sample() # equal probability of 0, 1, 2, 3 tensor([ 21., 24., 30., 25.]) >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x) tensor([-4.1338])
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}#
- property support#
- 返回型別
_DependentProperty
多元正態分佈#
- class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[原始碼]#
Bases:
Distribution建立一個由均值向量和協方差矩陣引數化的多元正態 (也稱為高斯) 分佈。
多元正態分佈可以根據正定協方差矩陣 或正定精度矩陣 或具有正值對角線的下三角矩陣 進行引數化,其中 . 此三角矩陣可透過例如協方差的 Cholesky 分解獲得。
示例
>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` tensor([-0.2102, -0.5429])
- 引數
注意
只能指定
covariance_matrix或precision_matrix或scale_tril中的一個。使用
scale_tril將更有效:所有內部計算都基於scale_tril。如果改為傳遞covariance_matrix或precision_matrix,它僅用於透過 Cholesky 分解計算相應的下三角矩陣。- arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': IndependentConstraint(Real(), 1), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}#
- has_rsample = True#
- support = IndependentConstraint(Real(), 1)#
NegativeBinomial#
- 類 torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)[源]#
Bases:
Distribution建立一個負二項分佈,即在達到
total_count次失敗之前的成功獨立且相同的伯努利試驗次數的分佈。每次伯努利試驗的成功機率為probs。- 引數
- arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}#
- support = IntegerGreaterThan(lower_bound=0)#
Normal#
- 類 torch.distributions.normal.Normal(loc, scale, validate_args=None)[源]#
Bases:
ExponentialFamily使用
loc和scale引數化的正態(也稱為高斯)分佈。示例
>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0])) >>> m.sample() # normally distributed with loc=0 and scale=1 tensor([ 0.1046])
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Real()#
OneHotCategorical#
- 類 torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)[源]#
Bases:
Distribution樣本是大小為
probs.size(-1)的獨熱編碼向量。注意
probs 引數必須是非負的、有限的且非零和的,它將在最後一個維度上歸一化為和為 1。
probs將返回此歸一化值。 logits 引數將被解釋為未歸一化的對數機率,因此可以是任何實數。它也將被歸一化,使得結果機率在最後一個維度上之和為 1。logits將返回此歸一化值。另請參閱:
torch.distributions.Categorical(),用於probs和logits的規範。示例
>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # equal probability of 0, 1, 2, 3 tensor([ 0., 0., 0., 1.])
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}#
- has_enumerate_support = True#
- support = OneHot()#
Pareto#
- 類 torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)[源]#
Bases:
TransformedDistribution從帕累託 I 型分佈中取樣。
示例
>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1 tensor([ 1.5623])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}#
- 屬性 support: Constraint#
- 返回型別
_DependentProperty
Poisson#
- 類 torch.distributions.poisson.Poisson(rate, validate_args=None)[源]#
Bases:
ExponentialFamily使用率引數
rate引數化的泊松分佈。樣本是非負整數,其機率質量函式 (pmf) 如下:
示例
>>> m = Poisson(torch.tensor([4])) >>> m.sample() tensor([ 3.])
- 引數
rate (數字, 張量) – 速率引數
- arg_constraints = {'rate': GreaterThanEq(lower_bound=0.0)}#
- support = IntegerGreaterThan(lower_bound=0)#
RelaxedBernoulli#
- 類 torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[源]#
Bases:
TransformedDistribution建立一個 RelaxedBernoulli 分佈,由
temperature引數化,以及probs或logits(但不能同時使用)。這是 Bernoulli 分佈的鬆弛版本,因此其值在 (0, 1) 範圍內,並且具有可重引數化的樣本。示例
>>> m = RelaxedBernoulli(torch.tensor([2.2]), ... torch.tensor([0.1, 0.2, 0.3, 0.99])) >>> m.sample() tensor([ 0.2951, 0.3442, 0.8918, 0.9021])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}#
- base_dist: LogitRelaxedBernoulli#
- has_rsample = True#
- support = Interval(lower_bound=0.0, upper_bound=1.0)#
LogitRelaxedBernoulli#
- 類 torch.distributions.relaxed_bernoulli.LogitRelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[源]#
Bases:
Distribution使用
probs或logits(但不能同時使用)引數化的 LogitRelaxedBernoulli 分佈,它是 RelaxedBernoulli 分佈的對數機率。有關更多詳細資訊,請參閱 [1]。樣本是 (0, 1) 範圍內的值的對數機率。有關更多詳細資訊,請參閱 [1]。
[1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al., 2017)
[2] Categorical Reparametrization with Gumbel-Softmax (Jang et al., 2017)
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}#
- support = Real()#
RelaxedOneHotCategorical#
- 類 torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)[源]#
Bases:
TransformedDistribution建立一個 RelaxedOneHotCategorical 分佈,由
temperature引數化,以及probs或logits。這是OneHotCategorical分佈的鬆弛版本,因此其樣本在單純形上,並且是可重引數化的。示例
>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), ... torch.tensor([0.1, 0.2, 0.3, 0.4])) >>> m.sample() tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex()}#
- base_dist: ExpRelaxedCategorical#
- has_rsample = True#
- support = Simplex()#
StudentT#
- 類 torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)[源]#
Bases:
Distribution使用自由度
df、均值loc和尺度scale引數化的 Student’s t 分佈。示例
>>> m = StudentT(torch.tensor([2.0])) >>> m.sample() # Student's t-distributed with degrees of freedom=2 tensor([ 0.1046])
- arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}#
- has_rsample = True#
- support = Real()#
TransformedDistribution#
- 類 torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)[源]#
Bases:
DistributionDistribution 類的一個擴充套件,它將一系列變換應用於基本分佈。令 f 為應用的變換的組合
X ~ BaseDistribution Y = f(X) ~ TransformedDistribution(BaseDistribution, f) log p(Y) = log p(X) + log |det (dX/dY)|
請注意,
TransformedDistribution的.event_shape是其基本分佈和變換的最大形狀,因為變換可能會引入事件之間的相關性。對
TransformedDistribution用法的示例是# Building a Logistic Distribution # X ~ Uniform(0, 1) # f = a + b * logit(X) # Y ~ f(X) ~ Logistic(a, b) base_distribution = Uniform(0, 1) transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] logistic = TransformedDistribution(base_distribution, transforms)
有關更多示例,請檢視
Gumbel、HalfCauchy、HalfNormal、LogNormal、Pareto、Weibull、RelaxedBernoulli和RelaxedOneHotCategorical的實現。- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {}#
- rsample(sample_shape=torch.Size([]))[源]#
生成一個 sample_shape 大小的可重引數化樣本,或者如果分佈引數是批次的,則生成 sample_shape 大小的可重引數化樣本批。首先從基本分佈取樣,然後為列表中的每個變換應用 transform()。
- 返回型別
- sample(sample_shape=torch.Size([]))[源]#
生成一個 sample_shape 大小的樣本,或者如果分佈引數是批次的,則生成 sample_shape 大小的樣本批。首先從基本分佈取樣,然後為列表中的每個變換應用 transform()。
- 屬性 support#
- 返回型別
_DependentProperty
Uniform#
- 類 torch.distributions.uniform.Uniform(low, high, validate_args=None)[源]#
Bases:
Distribution在半開區間
[low, high)中生成均勻分佈的隨機樣本。示例
>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0])) >>> m.sample() # uniformly distributed in the range [0.0, 5.0) tensor([ 2.3418])
- 屬性 arg_constraints#
- has_rsample = True#
- 屬性 support#
- 返回型別
_DependentProperty
VonMises#
- 類 torch.distributions.von_mises.VonMises(loc, concentration, validate_args=None)[源]#
Bases:
Distribution一個圓 von Mises 分佈。
此實現使用極座標。
loc和value引數可以是任何實數(以便於無約束最佳化),但它們被解釋為模 2 pi 的角度。- 示例:
>>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # von Mises distributed with loc=1 and concentration=1 tensor([1.9777])
- 引數
loc (torch.Tensor) – 以弧度為單位的角度。
concentration (torch.Tensor) – 濃度引數
- arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'loc': Real()}#
- has_rsample = False#
- sample(sample_shape=torch.Size([]))[源]#
von Mises 分佈的取樣演算法基於以下論文:D.J. Best 和 N.I. Fisher,“Efficient simulation of the von Mises distribution。” Applied Statistics (1979): 152-157。
取樣始終在內部以雙精度進行,以避免在_rejection_sample()函式中因濃度值過小而導致的掛起,這種情況在單精度下大約是1e-4時就會發生(請參閱問題#88443)。
- support = Real()#
Weibull#
- class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)[source]#
Bases:
TransformedDistribution從兩引數威布林分佈中取樣。
示例
>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0])) >>> m.sample() # sample from a Weibull distribution with scale=1, concentration=1 tensor([ 0.4784])
- 引數
- arg_constraints: dict[str, torch.distributions.constraints.Constraint] = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}#
- support = GreaterThan(lower_bound=0.0)#
Wishart#
- class torch.distributions.wishart.Wishart(df, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]#
Bases:
ExponentialFamily建立一個由對稱正定矩陣 或其喬列斯基分解 引數化的 Wishart 分佈。
示例
>>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2)) >>> m.sample() # Wishart distributed with mean=`df * I` and >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
- 引數
注意
只能指定
covariance_matrix、precision_matrix或scale_tril中的一個。使用scale_tril會更有效:內部所有計算都基於scale_tril。如果改用covariance_matrix或precision_matrix,則僅用於透過喬列斯基分解計算相應的下三角矩陣。‘torch.distributions.LKJCholesky’ 是一個受限的 Wishart 分佈。[1]參考文獻
[1] Wang, Z., Wu, Y. and Chu, H., 2018. On equivalence of the LKJ distribution and the restricted Wishart distribution. [2] Sawyer, S., 2007. Wishart Distributions and Inverse-Wishart Sampling. [3] Anderson, T. W., 2003. An Introduction to Multivariate Statistical Analysis (3rd ed.). [4] Odell, P. L. & Feiveson, A. H., 1966. A Numerical Procedure to Generate a SampleCovariance Matrix. JASA, 61(313):199-203. [5] Ku, Y.-C. & Bloomfield, P., 2010. Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX.
- property arg_constraints#
- has_rsample = True#
- rsample(sample_shape=torch.Size([]), max_try_correction=None)[source]#
警告
在某些情況下,基於Bartlett分解的取樣演算法可能會返回奇異的矩陣樣本。預設情況下會嘗試幾次來糾正奇異樣本,但最終可能仍會返回奇異樣本。奇異樣本在 .log_prob() 中可能返回 -inf 值。在這種情況下,使用者應驗證樣本,並相應地修復 df 的值或調整 .rsample 引數中的 max_try_correction 值。
- 返回型別
- support = PositiveDefinite()#
KL Divergence#
- torch.distributions.kl.kl_divergence(p, q)[source]#
計算兩個分佈之間的 Kullback-Leibler 散度 。
- 引數
p (Distribution) – 一個
Distribution物件。q (Distribution) – 一個
Distribution物件。
- 返回
形狀為 batch_shape 的 KL 散度批。
- 返回型別
- 引發
NotImplementedError – 如果分佈型別尚未透過
register_kl()註冊。
- KL 散度目前為以下分佈對實現:
Bernoulli和BernoulliBernoulli和PoissonBeta和BetaBeta和ContinuousBernoulliBeta和ExponentialBeta和GammaBeta和NormalBeta和ParetoBeta和UniformBinomial和BinomialCategorical和CategoricalCauchy和CauchyContinuousBernoulli和ContinuousBernoulliContinuousBernoulli和ExponentialContinuousBernoulli和NormalContinuousBernoulli和ParetoContinuousBernoulli和UniformDirichlet和DirichletExponential和BetaExponential和ContinuousBernoulliExponential和ExponentialExponential和GammaExponential和GumbelExponential和NormalExponential和ParetoExponential和UniformExponentialFamily和ExponentialFamilyGamma和BetaGamma和ContinuousBernoulliGamma和ExponentialGamma和GammaGamma和GumbelGamma和NormalGamma和ParetoGamma和UniformGeometric和GeometricGumbel和BetaGumbel和ContinuousBernoulliGumbel和ExponentialGumbel和GammaGumbel和GumbelGumbel和NormalGumbel和ParetoGumbel和UniformHalfNormal和HalfNormalIndependent和IndependentLaplace和BetaLaplace和ContinuousBernoulliLaplace和ExponentialLaplace和GammaLaplace和LaplaceLaplace和NormalLaplace和ParetoLaplace和UniformLowRankMultivariateNormal和LowRankMultivariateNormalLowRankMultivariateNormal和MultivariateNormalMultivariateNormal和LowRankMultivariateNormalMultivariateNormal和MultivariateNormalNormal和BetaNormal和ContinuousBernoulliNormal和ExponentialNormal和GammaNormal和GumbelNormal和LaplaceNormal和NormalNormal和ParetoNormal和UniformOneHotCategorical和OneHotCategoricalPareto和BetaPareto和ContinuousBernoulliPareto和ExponentialPareto和GammaPareto和NormalPareto和ParetoPareto和UniformPoisson和BernoulliPoisson和BinomialPoisson和PoissonTransformedDistribution和TransformedDistributionUniform和BetaUniform和ContinuousBernoulliUniform和ExponentialUniform和GammaUniform和GumbelUniform和NormalUniform和ParetoUniform和Uniform
- torch.distributions.kl.register_kl(type_p, type_q)[source]#
裝飾器,用於將成對函式註冊到
kl_divergence()。用法@register_kl(Normal, Normal) def kl_normal_normal(p, q): # insert implementation here
查詢返回最具體的 (型別,型別) 匹配,按子類排序。如果匹配不明確,則會引發 RuntimeWarning。例如,要解決不明確的情況
@register_kl(BaseP, DerivedQ) def kl_version1(p, q): ... @register_kl(DerivedP, BaseQ) def kl_version2(p, q): ...
您應該註冊第三個最具體的實現,例如
register_kl(DerivedP, DerivedQ)(kl_version1) # Break the tie.
Transforms#
- class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)[source]#
透過逐點仿射對映 進行變換。
- class torch.distributions.transforms.CatTransform(tseq, dim=0, lengths=None, cache_size=0)[source]#
變換函子,它以與
torch.cat()相容的方式,對 dim 維度上的每個子矩陣(長度為 lengths[dim])應用變換序列 tseq。示例
x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0) x = torch.cat([x0, x0], dim=0) t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) y = t(x)
- class torch.distributions.transforms.ComposeTransform(parts, cache_size=0)[source]#
將多個變換鏈式組合。被組合的變換負責快取。
- class torch.distributions.transforms.CorrCholeskyTransform(cache_size=0)[source]#
將無約束實向量 (長度為 )轉換為 D 維相關矩陣的喬列斯基因子。此喬列斯基因子是一個下三角矩陣,每行的對角線為正且歐幾里得範數為單位。變換的處理方式如下:
首先,我們將 x 按行順序轉換為下三角矩陣。
對於下三角部分的每一行 ,我們應用
StickBreakingTransform類的 *有符號* 版本來變換 到單位歐幾里得長度向量,步驟如下: - 縮放到區間 域: . - 轉換為無符號域: . - 應用 . - 轉換回有符號域: .
- class torch.distributions.transforms.CumulativeDistributionTransform(distribution, cache_size=0)[source]#
透過機率分佈的累積分佈函式進行變換。
- 引數
distribution (Distribution) – 用於變換的累積分佈函式所基於的分佈。
示例
# Construct a Gaussian copula from a multivariate normal. base_dist = MultivariateNormal( loc=torch.zeros(2), scale_tril=LKJCholesky(2).sample(), ) transform = CumulativeDistributionTransform(Normal(0, 1)) copula = TransformedDistribution(base_dist, [transform])
- class torch.distributions.transforms.IndependentTransform(base_transform, reinterpreted_batch_ndims, cache_size=0)[source]#
另一個變換的包裝器,用於將
reinterpreted_batch_ndims個最右邊的維度視為相關的。這不會影響正向或反向變換,但在log_abs_det_jacobian()中會合並reinterpreted_batch_ndims個最右邊的維度。
- class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)[source]#
從無約束矩陣到對角線元素為非負的下三角矩陣的變換。
這對於用喬列斯基分解引數化正定矩陣很有用。
- class torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0)[source]#
從無約束矩陣到正定矩陣的變換。
- class torch.distributions.transforms.ReshapeTransform(in_shape, out_shape, cache_size=0)[source]#
單位雅可比變換,用於重塑張量的最右邊部分。
請注意,
in_shape和out_shape必須具有相同的元素數量,這與torch.Tensor.reshape()相同。- 引數
in_shape (torch.Size) – 輸入的事件形狀。
out_shape (torch.Size) – 輸出的事件形狀。
cache_size (int) – 快取大小。如果為零,則不進行快取。如果為一,則快取最新的單個值。僅支援 0 和 1。(預設 0。)
- class torch.distributions.transforms.SoftplusTransform(cache_size=0)[source]#
透過對映 進行變換。當 時,實現會退回到線性函式。
- class torch.distributions.transforms.TanhTransform(cache_size=0)[source]#
透過對映 進行變換。
它等價於
ComposeTransform( [ AffineTransform(0.0, 2.0), SigmoidTransform(), AffineTransform(-1.0, 2.0), ] )
然而,這可能在數值上不穩定,因此建議使用 TanhTransform 代替。
請注意,當遇到 值時,應使用 cache_size=1。
- class torch.distributions.transforms.SoftmaxTransform(cache_size=0)[source]#
透過 並隨後進行歸一化,將非約束空間變換到單純形。
這不是雙射變換,不能用於 HMC。然而,它主要在座標系上操作(除了最後的歸一化),因此適用於逐座標最佳化演算法。
- class torch.distributions.transforms.StackTransform(tseq, dim=0, cache_size=0)[source]#
變換函子,它以與
torch.stack()相容的方式,逐個分量地應用一系列變換 tseq 到每個子矩陣上,dim 為維度。示例
x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1) t = StackTransform([ExpTransform(), identity_transform], dim=1) y = t(x)
- class torch.distributions.transforms.StickBreakingTransform(cache_size=0)[source]#
透過“折斷”過程將非約束空間變換到多一個維度的單純形。
此變換是透過對 Dirichlet 分佈的“折斷”構造中的迭代 Sigmoid 變換產生的:第一個 logit 透過 sigmoid 變換為第一個機率以及剩餘的機率,然後過程遞迴。
這是雙射變換,適合用於 HMC;然而,它混合了座標,不太適合最佳化。
- class torch.distributions.transforms.Transform(cache_size=0)[source]#
可逆變換的抽象類,其對數行列式雅可比行列式是可計算的。它們主要用於
torch.distributions.TransformedDistribution。快取對於那些逆變換計算成本高昂或數值不穩定的變換很有用。請注意,必須小心處理記憶值,因為自動微分圖可能會被反轉。例如,以下程式碼在有快取或無快取時都能正常工作
y = t(x) t.log_abs_det_jacobian(x, y).backward() # x will receive gradients.
然而,以下程式碼在快取時會因依賴關係反轉而報錯
y = t(x) z = t.inv(y) grad(z.sum(), [y]) # error because z is x
派生類應實現
_call()或_inverse()中的一個或兩個。將 bijective=True 設定為派生類還應實現log_abs_det_jacobian()。- 引數
cache_size (int) – 快取大小。如果為零,則不進行快取。如果為一,則快取最新的單個值。僅支援 0 和 1。
- 變數
domain (
Constraint) – 表示此變換有效輸入的約束。codomain (
Constraint) – 表示此變換有效輸出(作為逆變換的輸入)的約束。bijective (bool) – 此變換是否為雙射。當且僅當對於域中的每個
x和共域中的y,t.inv(t(x)) == x且t(t.inv(y)) == y時,變換t才為雙射。非雙射變換至少應滿足更弱的偽逆性質t(t.inv(t(x)) == t(x)和t.inv(t(t.inv(y))) == t.inv(y)。sign (int 或 Tensor) – 對於雙射單變數變換,此值應為 +1 或 -1,取決於變換是單調遞增還是遞減。
Constraints#
- class torch.distributions.constraints.Constraint[source]#
約束的抽象基類。
一個約束物件代表一個變數有效的區域,例如,變數可以在其中被最佳化。
- 變數
- torch.distributions.constraints.is_dependent(constraint)[source]#
檢查
constraint是否是_Dependent物件。- 引數
constraint – 一個
Constraint物件。- 返回
如果
constraint可以被細化為_Dependent型別,則返回 True,否則返回 False。- 返回型別
布林值
示例
>>> import torch >>> from torch.distributions import Bernoulli >>> from torch.distributions.constraints import is_dependent
>>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True)) >>> constraint1 = dist.arg_constraints["probs"] >>> constraint2 = dist.arg_constraints["logits"]
>>> for constraint in [constraint1, constraint2]: >>> if is_dependent(constraint): >>> continue
Constraint Registry#
PyTorch 提供了兩個全域性 ConstraintRegistry 物件,它們將 Constraint 物件連結到 Transform 物件。這兩個物件都接受約束並返回變換,但它們對雙射性的保證不同。
biject_to(constraint)查詢一個從constraints.real到給定constraint的雙射Transform。返回的變換保證.bijective = True,並且應該實現.log_abs_det_jacobian()。transform_to(constraint)查詢一個非必要雙射的Transform,從constraints.real到給定constraint。返回的變換不保證實現.log_abs_det_jacobian()。
transform_to() 登錄檔對於對機率分佈的有約束引數執行無約束最佳化很有用,這些引數由每個分佈的 .arg_constraints 字典指示。這些變換通常會過度引數化一個空間以避免旋轉;因此它們更適合逐座標最佳化演算法,如 Adam。
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints["scale"])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
biject_to() 登錄檔對於 Hamiltonian Monte Carlo(HMC)很有用,其中來自具有有約束 .support 的機率分佈的樣本在無約束空間中傳播,並且演算法通常是旋轉不變的。
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
注意
transform_to 和 biject_to 不同的一個例子是 constraints.simplex: transform_to(constraints.simplex) 返回一個 SoftmaxTransform,它只是簡單地對輸入進行指數化和歸一化;這是一個廉價且主要為座標系的運算,適用於 SVI 等演算法。相比之下,biject_to(constraints.simplex) 返回一個 StickBreakingTransform,它將輸入雙射到維度少一個的空間;這是一個更昂貴、數值不太穩定的變換,但對於 HMC 等演算法是必需的。
biject_to 和 transform_to 物件可以透過使用者定義的約束和變換來擴充套件,使用它們的 .register() 方法,可以是作為單例約束上的函式
transform_to.register(my_constraint, my_transform)
或作為引數化約束的裝飾器
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
assert isinstance(constraint, MyConstraintClass)
return MyTransform(constraint.param1, constraint.param2)
您可以透過建立一個新的 ConstraintRegistry 物件來建立自己的登錄檔。
- class torch.distributions.constraint_registry.ConstraintRegistry[source]#
連結約束到變換的登錄檔。
- register(constraint, factory=None)[source]#
在此登錄檔中註冊一個
Constraint子類。用法@my_registry.register(MyConstraintClass) def construct_transform(constraint): assert isinstance(constraint, MyConstraint) return MyTransform(constraint.arg_constraints)
- 引數
constraint (子類
Constraint) –Constraint的子類,或所需類的單例物件。factory (Callable) – 一個可呼叫物件,它接受一個約束物件作為輸入並返回一個
Transform物件。