評價此頁

CausalVariant#

class torch.nn.attention.bias.CausalVariant(value)[source]#

用於注意力機制的因果變體列舉。

定義兩種因果偏差型別

UPPER_LEFT: 代表標準因果注意力的左上三角偏差。用於構建此偏差的等效 pytorch 程式碼為

torch.tril(torch.ones(size, dtype=torch.bool))

例如,當 shape=(3,4) 時,物化偏差張量將為

[[1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0]]

LOWER_RIGHT: 代表右下三角偏差,包含值對齊到矩陣的右下角。

構造此掩碼的等效 PyTorch 程式碼為:

diagonal_offset = size[1] - size[0]
torch.tril(
    torch.ones(size, dtype=torch.bool),
    diagonal=diagonal_offset,
)

例如,當 shape=(3,4) 時,物化偏差張量將為

[[1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]]

請注意,當查詢和鍵/值張量的序列長度相等時,這些變體是等效的,因為三角矩陣是正方形的。

警告

此列舉是一個原型,可能會發生更改。