torch.set_float32_matmul_precision#
- torch.set_float32_matmul_precision(precision)[source]#
設定 float32 矩陣乘法的內部精度。
以較低的精度執行 float32 矩陣乘法可以顯著提高效能,並且在某些程式中精度損失的影響可以忽略不計。
支援三種設定
“highest”(最高),float32 矩陣乘法在內部計算中使用 float32 資料型別(24 位尾數,23 位顯式儲存)。
“high”(高),float32 矩陣乘法使用 TensorFloat32 資料型別(10 位尾數,顯式儲存)或將每個 float32 數視為兩個 bfloat16 數的和(約 16 位尾數,14 位顯式儲存),前提是存在相應的快速矩陣乘法演算法。否則,float32 矩陣乘法將按“highest”精度計算。有關 bfloat16 方法的更多資訊,請參閱下文。
“medium”(中等),float32 矩陣乘法使用 bfloat16 資料型別(8 位尾數,7 位顯式儲存)進行內部計算,前提是存在使用該資料型別進行內部計算的快速矩陣乘法演算法。否則,float32 矩陣乘法將按“high”精度計算。
使用“high”精度時,float32 乘法可能使用一種基於 bfloat16 的演算法,該演算法比簡單地截斷到較小的尾數位數(例如,TensorFloat32 為 10 位,bfloat16 顯式儲存為 7 位)更復雜。有關該演算法的完整描述,請參閱 [Henry2019]。在此簡要解釋,第一步是認識到我們可以將一個 float32 數完美地編碼為三個 bfloat16 數的和(因為 float32 有 23 位尾數,而 bfloat16 有 7 位顯式儲存,並且兩者具有相同的指數位數)。這意味著兩個 float32 數的乘積可以精確地表示為九個 bfloat16 數乘積的和。然後,我們可以透過丟棄其中一些乘積來權衡精度與速度。“high”精度演算法特別只保留了三個最重要的乘積,這恰好排除了所有涉及任一輸入最後 8 位尾數的乘積。這意味著我們可以將輸入表示為兩個 bfloat16 數的和,而不是三個。因為 bfloat16 熔合乘加 (FMA) 指令通常比 float32 指令快 10 倍以上,所以使用 bfloat16 精度進行三次乘法和 2 次加法比使用 float32 精度進行一次乘法要快。
注意
這不會改變 float32 矩陣乘法的輸出 dtype,它控制矩陣乘法的內部計算是如何執行的。
注意
這不會改變卷積運算的精度。其他標誌,如 torch.backends.cudnn.allow_tf32,可能會控制卷積運算的精度。
注意
此標誌目前僅影響一種原生裝置型別:CUDA。如果設定為“high”或“medium”,則在計算 float32 矩陣乘法時將使用 TensorFloat32 資料型別,這等同於將 torch.backends.cuda.matmul.allow_tf32 = True。當設定為“highest”(預設值)時,float32 資料型別用於內部計算,這等同於將 torch.backends.cuda.matmul.allow_tf32 = False。
- 引數
precision (str) – 可以設定為“highest”(預設值)、“high”或“medium”(見上文)。