評價此頁

torch.set_float32_matmul_precision#

torch.set_float32_matmul_precision(precision)[source]#

設定 float32 矩陣乘法的內部精度。

以較低的精度執行 float32 矩陣乘法可以顯著提高效能,並且在某些程式中精度損失的影響可以忽略不計。

支援三種設定

  • “highest”(最高),float32 矩陣乘法在內部計算中使用 float32 資料型別(24 位尾數,23 位顯式儲存)。

  • “high”(高),float32 矩陣乘法使用 TensorFloat32 資料型別(10 位尾數,顯式儲存)或將每個 float32 數視為兩個 bfloat16 數的和(約 16 位尾數,14 位顯式儲存),前提是存在相應的快速矩陣乘法演算法。否則,float32 矩陣乘法將按“highest”精度計算。有關 bfloat16 方法的更多資訊,請參閱下文。

  • “medium”(中等),float32 矩陣乘法使用 bfloat16 資料型別(8 位尾數,7 位顯式儲存)進行內部計算,前提是存在使用該資料型別進行內部計算的快速矩陣乘法演算法。否則,float32 矩陣乘法將按“high”精度計算。

使用“high”精度時,float32 乘法可能使用一種基於 bfloat16 的演算法,該演算法比簡單地截斷到較小的尾數位數(例如,TensorFloat32 為 10 位,bfloat16 顯式儲存為 7 位)更復雜。有關該演算法的完整描述,請參閱 [Henry2019]。在此簡要解釋,第一步是認識到我們可以將一個 float32 數完美地編碼為三個 bfloat16 數的和(因為 float32 有 23 位尾數,而 bfloat16 有 7 位顯式儲存,並且兩者具有相同的指數位數)。這意味著兩個 float32 數的乘積可以精確地表示為九個 bfloat16 數乘積的和。然後,我們可以透過丟棄其中一些乘積來權衡精度與速度。“high”精度演算法特別只保留了三個最重要的乘積,這恰好排除了所有涉及任一輸入最後 8 位尾數的乘積。這意味著我們可以將輸入表示為兩個 bfloat16 數的和,而不是三個。因為 bfloat16 熔合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,所以使用 bfloat16 精度進行三次乘法和 2 次加法比使用 float32 精度進行一次乘法要快。

Henry2019

http://arxiv.org/abs/1904.06376

注意

這不會改變 float32 矩陣乘法的輸出 dtype,它控制矩陣乘法的內部計算是如何執行的。

注意

這不會改變卷積運算的精度。其他標誌,如 torch.backends.cudnn.allow_tf32,可能會控制卷積運算的精度。

注意

此標誌目前僅影響一種原生裝置型別:CUDA。如果設定為“high”或“medium”,則在計算 float32 矩陣乘法時將使用 TensorFloat32 資料型別,這等同於將 torch.backends.cuda.matmul.allow_tf32 = True。當設定為“highest”(預設值)時,float32 資料型別用於內部計算,這等同於將 torch.backends.cuda.matmul.allow_tf32 = False

引數

precision (str) – 可以設定為“highest”(預設值)、“high”或“medium”(見上文)。