評價此頁

廣播語義#

創建於: 2017 年 4 月 27 日 | 最後更新於: 2021 年 1 月 31 日

許多 PyTorch 操作支援 NumPy 的廣播語義。有關詳細資訊,請參閱 https://numpy.org/doc/stable/user/basics.broadcasting.html

簡而言之,如果一個 PyTorch 操作支援廣播,那麼它的 Tensor 引數可以自動擴充套件為相同的大小(無需複製資料)。

通用語義#

如果滿足以下規則,則兩個 Tensor 稱為“可廣播”:

  • 每個 Tensor 至少有一個維度。

  • 當從最後一個維度開始迭代維度大小時,維度大小必須相等,其中一個大小為 1,或者其中一個不存在。

舉個例子

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

如果兩個 Tensor xy 是“可廣播”的,則結果 Tensor 的大小計算如下:

  • 如果 xy 的維度數量不相等,則在維度較少的 Tensor 的維度前面新增 1,使其長度相等。

  • 然後,對於每個維度大小,結果維度大小是 xy 在該維度上的大小的最大值。

舉個例子

# can line up trailing dimensions to make reading easier
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# but not necessary:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

原地語義#

一個複雜之處在於,原地操作不允許原地 Tensor 因廣播而改變形狀。

舉個例子

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向後相容性#

PyTorch 的早期版本允許某些逐元素函式在具有不同形狀的 Tensor 上執行,只要每個 Tensor 的元素數量相等即可。然後,逐元素操作將透過將每個 Tensor 視為一維來執行。PyTorch 現在支援廣播,並且在 Tensor 不可廣播但元素數量相同的情況下,“一維”逐元素行為被視為已棄用,並將生成 Python 警告。

請注意,在兩個 Tensor 形狀不同但可廣播且元素數量相同的情況下,引入廣播可能會導致向後不相容的更改。例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

以前會產生大小為 torch.Size([4,1]) 的 Tensor,但現在會產生大小為 torch.Size([4,4]) 的 Tensor。為了幫助識別程式碼中可能存在的因廣播而導致的向後不相容的情況,您可以將 torch.utils.backcompat.broadcast_warning.enabled 設定為 True,這將在這種情況下生成 Python 警告。

舉個例子

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.