torch.functional.einsum
- torch.functional.einsum(equation, *operands) → Tensor
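Sums the product of the elements of the input operands along dimensions specified using a notation based on the Einstein summation convention.
Einsum allows computing many common N-dimensional linear algebraic array operations by representing them in a short-hand format based on the Einstein summation convention, given by equation. The details of this format are described below, but the general idea is to label every dimension of the input operands with a subscript and to define which subscripts are part of the output. The output is then computed by summing the product of the elements of the operands along the dimensions whose subscripts are not part of the output. For example, matrix multiplication can be computed using einsum as torch.einsum("ij,jk->ik", A, B). Here, j is the summation subscript and i and k are the output subscripts (see the section below for more details).
Equation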
The equation string specifies the subscripts (letters in [a-zA-Z]) for each dimension of the input operands in the same order as the dimensions, separating subscripts for each operand by a comma (','), e.g. 'ij,jk' specifies subscripts for two 2D operands. The dimensions labeled with the same subscript must be broadcastable, that is, their size must either match or be 1. The exception is if a subscript is repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that appear exactly once in the equation will be part of the output, sorted in increasing alphabetical order. The output is computed by multiplying the input operands element-wise, with their dimensions aligned based on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation followed by the subscripts for the output. For instance, the following equation computes the transpose of a matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and at most once for the output.
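The rules above can be sketched with a few small tensors. This is a minimal illustration, not part of the original reference; the tensor names and shapes (A, B, M) are arbitrary choices made for the example.

```python
import torch

torch.manual_seed(0)  # fixed seed so repeated runs build the same inputs
A = torch.randn(2, 3)
B = torch.randn(3, 4)

# Explicit output: sum over j, keep i and k (ordinary matrix multiplication).
matmul = torch.einsum('ij,jk->ik', A, B)

# Implicit output: i and k each appear exactly once, so the output is 'ik'.
implicit = torch.einsum('ij,jk', A, B)

# Implicit output subscripts are sorted alphabetically: the input 'ba'
# produces the output 'ab', which is the transpose of A.
transposed = torch.einsum('ba', A)

# An explicit arrow can reorder the output: transpose of the matrix product.
product_t = torch.einsum('ij,jk->ki', A, B)

# A subscript repeated within one operand selects the diagonal.
M = torch.randn(4, 4)
diagonal = torch.einsum('ii->i', M)  # keep i: the diagonal itself
trace = torch.einsum('ii', M)        # i appears twice, so it is summed out
```

Comparing each result against torch.matmul, Tensor.T, torch.diagonal and torch.trace is a quick way to check that an equation means what you intend.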
Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis. Each input operand may contain at most one ellipsis, which will cover the dimensions not covered by subscripts, e.g. for an input operand with 5 dimensions, the ellipsis in the equation 'ab...c' covers the third and fourth dimensions. The ellipsis does not need to cover the same number of dimensions across the operands, but the 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions), before the subscript labels that appear exactly once for the input operands. For example, the following equation implements batch matrix multiplication: '...ij,...jk'.
A few final notes: the equation may contain whitespace between the different elements (subscripts, ellipsis, arrow and comma), but something like '. . .' is not valid. An empty string '' is valid for scalar operands.
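The ellipsis behavior described above can be sketched as follows. This is an illustrative example with shapes chosen for the purpose; it assumes only torch.einsum and torch.matmul, and the variable names are hypothetical.

```python
import torch

# The ellipsis covers whatever leading dimensions the subscripts do not name.
A = torch.randn(2, 3, 4, 5)  # '...ij': ellipsis covers dims of size (2, 3)
B = torch.randn(3, 5, 6)     # '...jk': ellipsis covers one dim of size (3,)

# The covered shapes (2, 3) and (3,) broadcast together to (2, 3).
explicit = torch.einsum('...ij,...jk->...ik', A, B)

# Without an arrow, the ellipsis dimensions come first, followed by the
# subscripts that appear exactly once (i and k) in alphabetical order.
implicit = torch.einsum('...ij,...jk', A, B)

# Whitespace between elements of the equation is allowed.
spaced = torch.einsum('...ij, ...jk -> ...ik', A, B)

result_shape = tuple(explicit.shape)
```

Here torch.matmul(A, B) applies the same batch broadcasting, so it serves as a convenient cross-check for the einsum result.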
Note
torch.einsum handles ellipsis ('...') differently from NumPy in that it allows dimensions covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
Note
Please install opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) to enable a more performant einsum. You can install it together with torch via pip install torch[opt-einsum], or on its own via pip install opt-einsum.
If opt-einsum is available, this function will automatically speed up computation and/or consume less memory by optimizing contraction order through our opt_einsum backend torch.backends.opt_einsum (The _ vs - is confusing, I know). This optimization occurs when there are at least three inputs, since the order does not matter otherwise. Note that finding the optimal path is an NP-hard problem, thus, opt-einsum relies on different heuristics to achieve near-optimal results. If opt-einsum is not available, the default order is to contract from left to right.
To bypass this default behavior, add the following to disable opt_einsum and skip path calculation: torch.backends.opt_einsum.enabled = False
To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line: torch.backends.opt_einsum.strategy = 'auto'. The default strategy is 'auto', and we also support 'greedy' and 'optimal'. Be aware that the runtime of 'optimal' is factorial in the number of inputs! See more details in the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).
Note
As of PyTorch 1.10, torch.einsum() also supports the sublist format (see examples below). In this format, subscripts for each operand are specified by sublists, which are lists of integers in the range [0, 52). These sublists follow their operands, and an extra sublist can appear at the end of the input to specify the output's subscripts, e.g. torch.einsum(op1, sublist1, op2, sublist2, ..., [sublist_out]). Python's Ellipsis object may be provided in a sublist to enable broadcasting as described in the Equation section above.
- Parameters
- Return type
Examples
>>> # trace
>>> torch.einsum('ii', torch.randn(4, 4))
tensor(-1.2104)

>>> # diagonal
>>> torch.einsum('ii->i', torch.randn(4, 4))
tensor([-0.1034,  0.7952, -0.2433,  0.4545])

>>> # outer product
>>> x = torch.randn(5)
>>> y = torch.randn(4)
>>> torch.einsum('i,j->ij', x, y)
tensor([[ 0.1156, -0.2897, -0.3918,  0.4963],
        [-0.3744,  0.9381,  1.2685, -1.6070],
        [ 0.7208, -1.8058, -2.4419,  3.0936],
        [ 0.1713, -0.4291, -0.5802,  0.7350],
        [ 0.5704, -1.4290, -1.9323,  2.4480]])

>>> # batch matrix multiplication
>>> As = torch.randn(3, 2, 5)
>>> Bs = torch.randn(3, 5, 4)
>>> torch.einsum('bij,bjk->bik', As, Bs)
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
         [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
         [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
         [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # with sublist format and ellipsis
>>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
         [-1.6706, -0.8097, -0.8025, -2.1183]],

        [[ 4.2239,  0.3107, -0.5756, -0.2354],
         [-1.4558, -0.3460,  1.5087, -0.8530]],

        [[ 2.8153,  1.8787, -4.3839, -1.2112],
         [ 0.3728, -2.1131,  0.0921,  0.8305]]])

>>> # batch permute
>>> A = torch.randn(2, 3, 4, 5)
>>> torch.einsum('...ij->...ji', A).shape
torch.Size([2, 3, 5, 4])

>>> # equivalent to torch.nn.functional.bilinear
>>> A = torch.randn(3, 5, 4)
>>> l = torch.randn(2, 5)
>>> r = torch.randn(2, 4)
>>> torch.einsum('bn,anm,bm->ba', l, A, r)
tensor([[-0.3430, -5.2405,  0.4494],
        [ 0.3311,  5.5201, -3.0356]])
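
The examples above do not exercise the opt_einsum backend settings mentioned in the notes, so the following is a minimal sketch of how they might be toggled. It assumes that torch.backends.opt_einsum.is_available() reports whether opt-einsum is installed and that assigning a strategy is only accepted when the package is present; the shapes and variable names are arbitrary.

```python
import torch

# Three operands: contraction order only matters with at least three inputs.
a = torch.randn(8, 16)
b = torch.randn(16, 32)
c = torch.randn(32, 4)
equation = 'ij,jk,kl->il'

# Result under the current backend settings.
reference = torch.einsum(equation, a, b, c)

# Disable path optimization so operands are contracted left to right,
# then restore whatever setting was active before.
previous_enabled = torch.backends.opt_einsum.enabled
torch.backends.opt_einsum.enabled = False
left_to_right = torch.einsum(equation, a, b, c)
torch.backends.opt_einsum.enabled = previous_enabled

# Assumption: a strategy can only be selected when opt-einsum is installed
# and the backend is enabled, so guard the assignment.
if torch.backends.opt_einsum.is_available() and torch.backends.opt_einsum.enabled:
    previous_strategy = torch.backends.opt_einsum.strategy
    torch.backends.opt_einsum.strategy = 'greedy'
    greedy = torch.einsum(equation, a, b, c)
    torch.backends.opt_einsum.strategy = previous_strategy
else:
    greedy = reference  # nothing to compare when opt-einsum is missing

# The contraction order affects cost, not the mathematical result.
matches_left_to_right = torch.allclose(reference, left_to_right, rtol=1e-4, atol=1e-4)
matches_greedy = torch.allclose(reference, greedy, rtol=1e-4, atol=1e-4)
```

Restoring the previous values keeps the change local to this snippet, which matters because these settings are process-wide.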