torch.nn.attention.bias.causal_upper_left#
- torch.nn.attention.bias.causal_upper_left(*size)[原始碼]#
建立一個左上角三角因果注意力偏差。
此函式生成一個左上角三角矩陣,用於表示因果注意力偏差,其對角線偏移設定正確,以便包含值與矩陣的左上角對齊。這等效於scaled_dot_product_attention中的is_causal=True引數。
構造此掩碼的等效 PyTorch 程式碼為:
torch.tril(torch.ones(size, dtype=torch.bool))
例如,當shape=(3,4)時,生成的偏差張量將是
[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
- 引數
size – 偏差矩陣的大小。
- 返回
左上角三角因果注意力偏差變體。
- 返回型別