torch.nn.attention.bias.CausalBias#
- class torch.nn.attention.bias.CausalBias(variant, seq_len_q, seq_len_kv)[source]#
A bias representing causal attention patterns. For an overview of the bias structure, see the
CausalVariant enum. This class is used to define causal (triangular) attention biases. Two factory functions are provided for constructing a bias:
causal_upper_left() and causal_lower_right().

Example
```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8

# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)
```
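For comparison, below is a minimal sketch using the other factory function, causal_upper_left(), which anchors the causal mask at the upper-left corner of the attention matrix; the batch size, head count, and CPU/float32 setup here are illustrative assumptions rather than requirements:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_upper_left

bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 2, 4, 4, 12, 8

# Upper-left causal bias: the causal mask is anchored at the upper-left
# corner of the (seqlen_q, seqlen_kv) attention matrix.
attn_bias = causal_upper_left(seqlen_q, seqlen_kv)

q = torch.randn(bsz, num_heads, seqlen_q, head_dim)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim)

out = F.scaled_dot_product_attention(q, k, v, attn_bias)
```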
Warning

This class is a prototype and subject to change.