評價此頁

torch.nn.attention.bias.causal_upper_left#

torch.nn.attention.bias.causal_upper_left(*size)[原始碼]#

建立一個左上角三角因果注意力偏差。

此函式生成一個左上角三角矩陣,用於表示因果注意力偏差,其對角線偏移設定正確,以便包含值與矩陣的左上角對齊。這等效於scaled_dot_product_attention中的is_causal=True引數。

構造此掩碼的等效 PyTorch 程式碼為:

torch.tril(torch.ones(size, dtype=torch.bool))

例如,當shape=(3,4)時,生成的偏差張量將是

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]
引數

size – 偏差矩陣的大小。

返回

左上角三角因果注意力偏差變體。

返回型別

CausalBias