評價此頁

torch.einsum#

torch.einsum(equation, *operands) Tensor[source]#

使用基於愛因斯坦求和約定的符號,對輸入 operands 的元素在指定維度上的乘積進行求和。

Einsum 允許透過基於愛因斯坦求和約定的簡寫格式來計算許多常見的**多維線性代數陣列運算**,該格式由 equation 指定。此格式的詳細資訊將在下文描述,但其基本思想是為輸入 operands 的每個維度分配一個下標,並定義哪些下標屬於輸出。然後,透過對輸入 operands 的元素進行乘積求和來計算輸出,求和的維度是那些下標不屬於輸出的維度。例如,矩陣乘法可以使用 einsum 來計算,如 torch.einsum(“ij,jk->ik”, A, B)。在這裡,j 是求和下標,i 和 k 是輸出下標(有關為什麼會這樣,請參閱下文)。

Equation

equation”字串以相同順序指定輸入 operands 每個維度的下標(`[a-zA-Z]` 中的字母),用逗號(‘,’)分隔每個運算元的下標,例如 ‘ij,jk’ 指定兩個二維運算元的下標。具有相同下標的維度必須是可廣播的,即它們的大小必須匹配或為 1。例外情況是,如果同一輸入運算元的下標重複出現,則該運算元中帶有此下標的維度大小必須匹配,並且該運算元將被替換為其沿這些維度的對角線。在 equation 中恰好出現一次的下標將成為輸出的一部分,並按字母順序升序排序。透過對輸入 operands 進行逐元素相乘來計算輸出,維度根據下標對齊,然後對不屬於輸出的下標所對應的維度進行求和。

可選地,可以透過在方程末尾新增箭頭(‘->’)並後跟輸出的下標來顯式定義輸出下標。例如,以下方程計算矩陣乘法的轉置:‘ij,jk->ki’。輸出下標必須在某個輸入運算元中至少出現一次,並在輸出中最多出現一次。

省略號(‘…’)可以用來代替下標來廣播省略號所覆蓋的維度。每個輸入運算元最多隻能包含一個省略號,它將覆蓋下標未覆蓋的維度。例如,對於具有 5 個維度的輸入運算元,方程 ‘ab…c’ 中的省略號將覆蓋第三個和第四個維度。省略號不需要覆蓋 operands 之間相同數量的維度,但省略號的“形狀”(它們所覆蓋的維度的尺寸)必須能夠一起廣播。如果輸出沒有用箭頭(‘->’)表示法顯式定義,則省略號將出現在輸出(最左邊的維度)中,排在輸入運算元中恰好出現一次的下標標籤之前。例如,以下方程實現了批矩陣乘法 ‘…ij,…jk’

最後幾點說明:方程中不同元素(下標、省略號、箭頭和逗號)之間可以包含空格,但像 ‘…’ 這樣的寫法無效。空字串 ‘’ 對於標量運算元是有效的。

注意

torch.einsum 處理省略號(‘…’)的方式與 NumPy 不同,它允許對省略號所覆蓋的維度進行求和,也就是說,省略號不需要成為輸出的一部分。

注意

請安裝 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/)以獲得更高效的 einsum。您可以在安裝 torch 時進行安裝,如下所示:pip install torch[opt-einsum],或者單獨安裝:pip install opt-einsum

如果 opt-einsum 可用,此函式將自動透過我們的 opt_einsum 後端 torch.backends.opt_einsum (“_”與“-”的混淆很令人困惑,我知道)最佳化收縮順序,從而加速計算和/或減少記憶體消耗。當輸入數量至少為三個時,就會發生這種最佳化,因為否則順序無關緊要。請注意,找到“最優”路徑是一個 NP-hard 問題,因此,opt-einsum 依賴於不同的啟發式方法來實現接近最優的結果。如果 opt-einsum 不可用,預設順序是從左到右進行收縮。

要繞過此預設行為,請新增以下行以停用 opt_einsum 並跳過路徑計算:torch.backends.opt_einsum.enabled = False

要指定 opt_einsum 用於計算收縮路徑的策略,請新增以下行:torch.backends.opt_einsum.strategy = 'auto'。預設策略是“auto”,我們還支援“greedy”和“optimal”。需要注意的是,“optimal”策略的執行時是輸入數量的階乘!有關更多詳細資訊,請參閱 opt_einsum 文件(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。

注意

從 PyTorch 1.10 開始,torch.einsum() 也支援子列表格式(請參閱下面的示例)。在此格式中,每個運算元的下標由子列表指定,即範圍在 [0, 52) 內的整數列表。這些子列表跟在它們的運算元後面,並且可以在輸入末尾新增一個額外的子列表來指定輸出的下標,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 物件可以作為子列表提供,以啟用上面“Equation”部分中描述的廣播。

引數
  • equation (str) – 用於愛因斯坦求和的下標。

  • operands (List[Tensor]) – 用於計算愛因斯坦求和的張量。

返回型別

張量

示例

>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
        [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
        [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
        [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])