評價此頁

torch.multinomial#

torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor#

返回一個張量,其中每行包含從張量 input 中相應行上的多項式(更嚴格的定義是多元的,有關更多詳細資訊,請參閱 torch.distributions.multinomial.Multinomial)機率分佈中取樣的 num_samples 個索引。

注意

input 的行不需要求和為一(在這種情況下,我們使用這些值作為權重),但必須是非負、有限且和非零的。

索引按照取樣順序從左到右排序(第一個取樣放在第一列)。

如果 input 是一個向量,則 out 是一個大小為 num_samples 的向量。

如果 input 是一個有 m 行的矩陣,則 out 是一個形狀為 (m×num_samples)(m \times \text{num\_samples}) 的矩陣。

如果 replacementTrue,則樣本有放回地抽取。

如果不是,則樣本無放回地抽取,這意味著當某個行的樣本索引被抽取後,該行不能再被抽取該索引。

注意

無放回抽取時,num_samples 必須小於 input 中非零元素的數量(如果 input 是一個矩陣,則為每行非零元素的最小數量)。

引數
  • input (Tensor) – 包含機率的輸入張量

  • num_samples (int) – 要抽取的樣本數量

  • replacement (bool, optional) – 是否有放回地抽取

關鍵字引數
  • generator (torch.Generator, optional) – 用於取樣的偽隨機數生成器

  • out (Tensor, optional) – 輸出張量。

示例

>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 5) # ERROR!
RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2,  1,  1,  1])