torch.multinomial#
- torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) LongTensor#
返回一個張量,其中每行包含從張量
input中相應行上的多項式(更嚴格的定義是多元的,有關更多詳細資訊,請參閱torch.distributions.multinomial.Multinomial)機率分佈中取樣的num_samples個索引。注意
input的行不需要求和為一(在這種情況下,我們使用這些值作為權重),但必須是非負、有限且和非零的。索引按照取樣順序從左到右排序(第一個取樣放在第一列)。
如果
input是一個向量,則out是一個大小為num_samples的向量。如果
input是一個有 m 行的矩陣,則out是一個形狀為 的矩陣。如果
replacement為True,則樣本有放回地抽取。如果不是,則樣本無放回地抽取,這意味著當某個行的樣本索引被抽取後,該行不能再被抽取該索引。
注意
無放回抽取時,
num_samples必須小於input中非零元素的數量(如果input是一個矩陣,則為每行非零元素的最小數量)。- 引數
- 關鍵字引數
generator (
torch.Generator, optional) – 用於取樣的偽隨機數生成器out (Tensor, optional) – 輸出張量。
示例
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights >>> torch.multinomial(weights, 2) tensor([1, 2]) >>> torch.multinomial(weights, 5) # ERROR! RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement >>> torch.multinomial(weights, 4, replacement=True) tensor([ 2, 1, 1, 1])