評價此頁

torch.bernoulli#

torch.bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) Tensor#

從伯努利分佈中抽取二元隨機數(0 或 1)。

輸入的 input 張量應包含用於抽取二元隨機數的機率。因此,input 中的所有值必須在以下範圍內:0inputi10 \leq \text{input}_i \leq 1.

輸出張量的第 ith\text{i}^{th} 個元素將根據 input 中給出的第 ith\text{i}^{th} 個機率值,抽取一個值為 11 的值。

outiBernoulli(p=inputi)\text{out}_{i} \sim \mathrm{Bernoulli}(p = \text{input}_{i})

返回的 out 張量只包含 0 或 1 的值,其形狀與 input 相同。

out 可以具有整數 dtype,但 input 必須具有浮點數 dtype

引數

input (Tensor) – 用於伯努利分佈的機率值的輸入張量

關鍵字引數
  • generator (torch.Generator, optional) – 用於取樣的偽隨機數生成器

  • out (Tensor, optional) – 輸出張量。

示例

>>> a = torch.empty(3, 3).uniform_(0, 1)  # generate a uniform random matrix with range [0, 1]
>>> a
tensor([[ 0.1737,  0.0950,  0.3609],
        [ 0.7148,  0.0289,  0.2676],
        [ 0.9456,  0.8937,  0.7202]])
>>> torch.bernoulli(a)
tensor([[ 1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 1.,  1.,  1.]])

>>> a = torch.ones(3, 3) # probability of drawing "1" is 1
>>> torch.bernoulli(a)
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> a = torch.zeros(3, 3) # probability of drawing "1" is 0
>>> torch.bernoulli(a)
tensor([[ 0.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0.,  0.,  0.]])