CausalVariant#
- class torch.nn.attention.bias.CausalVariant(value)[source]#
用於注意力機制的因果變體列舉。
定義兩種因果偏差型別
UPPER_LEFT: 代表標準因果注意力的左上三角偏差。用於構建此偏差的等效 pytorch 程式碼為torch.tril(torch.ones(size, dtype=torch.bool))
例如,當
shape=(3,4)時,物化偏差張量將為[[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
LOWER_RIGHT: 代表右下三角偏差,包含值對齊到矩陣的右下角。構造此掩碼的等效 PyTorch 程式碼為:
diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, )
例如,當
shape=(3,4)時,物化偏差張量將為[[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]
請注意,當查詢和鍵/值張量的序列長度相等時,這些變體是等效的,因為三角矩陣是正方形的。
警告
此列舉是一個原型,可能會發生更改。