torch.einsum#
- torch.einsum(equation, *operands) Tensor[source]#
使用基於愛因斯坦求和約定的符號,對輸入
operands的元素在指定維度上的乘積進行求和。Einsum 允許透過基於愛因斯坦求和約定的簡寫格式來計算許多常見的**多維線性代數陣列運算**,該格式由
equation指定。此格式的詳細資訊將在下文描述,但其基本思想是為輸入operands的每個維度分配一個下標,並定義哪些下標屬於輸出。然後,透過對輸入operands的元素進行乘積求和來計算輸出,求和的維度是那些下標不屬於輸出的維度。例如,矩陣乘法可以使用 einsum 來計算,如 torch.einsum(“ij,jk->ik”, A, B)。在這裡,j 是求和下標,i 和 k 是輸出下標(有關為什麼會這樣,請參閱下文)。Equation
“
equation”字串以相同順序指定輸入operands每個維度的下標(`[a-zA-Z]` 中的字母),用逗號(‘,’)分隔每個運算元的下標,例如 ‘ij,jk’ 指定兩個二維運算元的下標。具有相同下標的維度必須是可廣播的,即它們的大小必須匹配或為 1。例外情況是,如果同一輸入運算元的下標重複出現,則該運算元中帶有此下標的維度大小必須匹配,並且該運算元將被替換為其沿這些維度的對角線。在equation中恰好出現一次的下標將成為輸出的一部分,並按字母順序升序排序。透過對輸入operands進行逐元素相乘來計算輸出,維度根據下標對齊,然後對不屬於輸出的下標所對應的維度進行求和。可選地,可以透過在方程末尾新增箭頭(‘->’)並後跟輸出的下標來顯式定義輸出下標。例如,以下方程計算矩陣乘法的轉置:‘ij,jk->ki’。輸出下標必須在某個輸入運算元中至少出現一次,並在輸出中最多出現一次。
省略號(‘…’)可以用來代替下標來廣播省略號所覆蓋的維度。每個輸入運算元最多隻能包含一個省略號,它將覆蓋下標未覆蓋的維度。例如,對於具有 5 個維度的輸入運算元,方程 ‘ab…c’ 中的省略號將覆蓋第三個和第四個維度。省略號不需要覆蓋
operands之間相同數量的維度,但省略號的“形狀”(它們所覆蓋的維度的尺寸)必須能夠一起廣播。如果輸出沒有用箭頭(‘->’)表示法顯式定義,則省略號將出現在輸出(最左邊的維度)中,排在輸入運算元中恰好出現一次的下標標籤之前。例如,以下方程實現了批矩陣乘法 ‘…ij,…jk’。最後幾點說明:方程中不同元素(下標、省略號、箭頭和逗號)之間可以包含空格,但像 ‘…’ 這樣的寫法無效。空字串 ‘’ 對於標量運算元是有效的。
注意
torch.einsum處理省略號(‘…’)的方式與 NumPy 不同,它允許對省略號所覆蓋的維度進行求和,也就是說,省略號不需要成為輸出的一部分。注意
請安裝 opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/)以獲得更高效的 einsum。您可以在安裝 torch 時進行安裝,如下所示:pip install torch[opt-einsum],或者單獨安裝:pip install opt-einsum。
如果 opt-einsum 可用,此函式將自動透過我們的 opt_einsum 後端
torch.backends.opt_einsum(“_”與“-”的混淆很令人困惑,我知道)最佳化收縮順序,從而加速計算和/或減少記憶體消耗。當輸入數量至少為三個時,就會發生這種最佳化,因為否則順序無關緊要。請注意,找到“最優”路徑是一個 NP-hard 問題,因此,opt-einsum 依賴於不同的啟發式方法來實現接近最優的結果。如果 opt-einsum 不可用,預設順序是從左到右進行收縮。要繞過此預設行為,請新增以下行以停用 opt_einsum 並跳過路徑計算:
torch.backends.opt_einsum.enabled = False要指定 opt_einsum 用於計算收縮路徑的策略,請新增以下行:
torch.backends.opt_einsum.strategy = 'auto'。預設策略是“auto”,我們還支援“greedy”和“optimal”。需要注意的是,“optimal”策略的執行時是輸入數量的階乘!有關更多詳細資訊,請參閱 opt_einsum 文件(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。注意
從 PyTorch 1.10 開始,
torch.einsum()也支援子列表格式(請參閱下面的示例)。在此格式中,每個運算元的下標由子列表指定,即範圍在 [0, 52) 內的整數列表。這些子列表跟在它們的運算元後面,並且可以在輸入末尾新增一個額外的子列表來指定輸出的下標,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 物件可以作為子列表提供,以啟用上面“Equation”部分中描述的廣播。示例
>>> # trace >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # diagonal >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # outer product >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # batch matrix multiplication >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # with sublist format and ellipsis >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # batch permute >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # equivalent to torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])