近日,清華大學團隊打造了首個用于推理加速的微縮版 FP4 注意力機制——SageAttention3,在英偉達 RTX5090 上實現了 1038TOPS 的計算性能。相比此前在英偉達 RTX5090 上計算性能最快的、由美國斯坦福大學提出的 FlashAttention,SageAttention3 的計算性能快了 5 倍。實驗表明,SageAttention3 能夠加速各種模型,并且不會導致端到端質量指標的下降。
與此同時,研究團隊還打造了首個用于訓練加速的可訓練 8 比特注意力機制——SageBwd,并探討了它在訓練任務中的可行性。其發現,8 比特注意力機制可以在微調任務中實現無損性能,不過在當前階段的預訓練任務中仍存在一定局限性。
(來源:arXiv)
由于注意力機制的時間復雜度是 n2,因此注意力機制的效率非常重要。為此,他們通過兩個關鍵貢獻提高了注意力的效率:首先,研究團隊利用英偉達 Blackwell GPU 中的新 FP4 Tensor 內核來加速注意力計算。實驗表明,SageAttention3 能夠以即插即用的方式加速各種模型的推理。其次,研究團隊在訓練任務中率先采用了低比特注意力機制,而此前包括 FlashAttention3 和 SageAttention 在內的現有低比特注意力機制僅僅關注推理。
據該研究團隊所知,本次研究首次實現了面向推理加速的 FP4 注意力機制設計,并開創性地探索了低比特注意力在大型模型訓練中的可行性。目前,相關代碼已開源:https://github.com/thu-ml/SageAttention。
解決兩大障礙和一個難點
研究團隊在論文中表示,FP4 注意力機制面臨兩個主要障礙,而 8 比特可訓練注意力機制則面臨著一個關鍵難點。具體來說:
第一個問題是:FP4 量化的數值表示范圍極為有限(僅能表示 15 個可取值),導致無論是逐張量(per-tensor)還是逐詞元(per-token)的量化方法,均無法有效保持模型精度。
第二個問題是:注意力圖 P 主要由 [0,1] 范圍內的小值組成。(注:注意力圖 P 是 Self-Attention 中的核心輸出矩陣,表示輸入序列中所有位置之間的相關性權重。)若直接量化為 FP4 格式,這些數值會迫使擴展因子的動態范圍被極度壓縮。然而,硬件要求量化因子必須采用 FP8 數據類型,這一限制導致縮放因子以 FP8 格式表示時會產生顯著的精度損失。
第三個問題是:在訓練過程中使用 8 比特注意力機制時,研究團隊發現注意力圖的梯度特別容易受到量化誤差的影響,從而導致輸入梯度中的誤差累積。
為了解決第一個問題,研究團隊提出針對注意力機制中的兩次矩陣乘法,即 QK? 和 PV 中使用 FP4 微縮放量化方法。通過將量化組大小限制為 1x16(而非基于張量或通道),讓本次方法在提高 FP4 量化精度的同時,能夠有效抑制每個塊內的異常值影響。
為了解決第二個問題,研究團隊提出了一種針對注意力圖 P 的兩級量化方法,從而充分利用了 FP8 縮放因子的表示范圍,提高了注意力圖 P 的量化精度。具體而言,該方法首先通過逐 token 量化將每個 token 的數值范圍歸一化至 [0, 448 × 6],隨后采用 FP4 微縮放量化來提升精度。
為了解決第三個問題,研究團隊在反向傳播涉及的五個矩陣乘法運算中,識別出對精度最為敏感的那個,并將其精度保持在 FP16 級別。
FP4 注意推理加速以及硬件實現與優化
在數據類型的確定上,FP4 數據類型有著兩種選擇。第一個選擇是 NVFP4,其數據類型為 E2M1,量化塊大小為 1×16,擴展因子為 E4M3 數據類型。第二個選擇是 MXFP4,它也是 E2M1 數據類型,然而其量化塊大小為 1×32,擴展因子為 E8M0 數據類型。
一番對比之后,研究團隊選擇了 NVFP4,這是因為 NVFP4 在注意力量化方面的精度遠高于 MXFP4。下表展示了在 AI 視頻生成模型 CogVideoX 所有層上使用實數 Q、K、V 的 MXFP4 和 NVFP4 的準確性。結果表明,NVFP4 的精度優于 MXFP4。
(來源:arXiv)
不同于 FP16,在 FP4 的矩陣乘法中,FP32 累加器的內存布局與其操作數 A 的寄存器布局不同。如果通過線程間數據交換來匹配操作數 A 的布局,會導致內核性能下降。研究團隊的方法是通過對 P tile 的列進行置換,來調整累加器的布局。為了保證矩陣乘法的正確性,研究團隊相應地重新排列 K 的列,這一過程可以與量化內核融合處理。
進行微縮放量化時,需要找到每行連續 16 個元素中的最大值。然而,這 16 個元素分布在 4 個線程中,這就需要線程內部先求最大值,再通過線程間的 shuffle 操作進行歸并,這大大拖慢了內核的執行速度。研究團隊針對這一做法進行了優化,即把量化過程與在線 softmax 融合處理,與此同時這種融合還能計算每行的最大值。
(來源:arXiv)
在傳統的 warp 專用內核中,消費者線程束通常同時執行矩陣乘法和存儲操作,而生產者線程束只是負責加載輸入數據,消費者線程束之間通過乒乓調度(ping-pong)調度實現階段重疊。
然而,在研究團隊的 FP4 注意力內核中,由于寄存器資源受限,這種方式無法實現。因此,研究團隊設計了新的方案,即在生產者線程束之間進行乒乓調度:當一個生產者線程束為下一次矩陣乘法操作加載輸入數據時,另一個生產者線程束同時將輸出結果存儲到全局內存中,而消費者線程束則僅負責將矩陣乘法的結果從寄存器轉移到共享內存中。
通過采用這種新穎的設計,讓他們在寄存器數量的限制下,實現了矩陣乘法和全局內存存儲操作的重疊,從而提高了吞吐量。
將 INT8 注意力用于訓練,并開展相關實驗
據了解,低比特量化注意力相關工作,比如 FlashAttention3 和 SageAttention,僅適用于推理場景。
如前所述,研究團隊提出了一種用于訓練的 INT8 注意力機制——SageBwd。該機制將注意力計算中的七個矩陣乘法里的六個量化為 INT8 精度,同時在微調任務中實現了零性能損失。
實驗中,研究團隊驗證了 SageAttention3 和 SageBwd 在語言、圖像和視頻生成等多種代表性模型中的有效性。
具體來說,他們在以下方面進行了實驗:
在文本到文本任務的測試實驗中,使用的是 Qwen2.5 和 Llama3.2;在文本到視頻任務的測試實驗中,使用的是 CogvideoX、HunyuanVideo 和 Mochi;在文本到圖像任務的測試實驗中,使用的是 Flux 和 Stable-Diffusion3.5。
研究團隊將本次方法與 FlashAttention2、xformers、SageAttention 和 SageAtteention2 進行了比較。
需要說明的是,FlashAttention3 只能在英偉達 Hopper GPU 上運行,因此 FlashAttention 2 已經是英偉達 RTX5090 和英偉達 RTX4090 上能運行的最快版本。
下圖展示了 SageAttention3 及其基線模型在 RTX 5090 上的內核運行速度。可以看出,SageAttention3 相較于 FlashAttention2 實現了 4~5 倍的加速,相較于 xformers 實現了 8~11 倍的加速。
(來源:arXiv)
下圖展示了 SageBwd 及其基線模型在英偉達 RTX 4090 上的“正向+反向”傳播的速度。結果表明,SageBwd 相較于 FlashAttention2 最多實現了 1.67 倍的加速,并且比基于 Triton 實現的 FlashAttention2 以及 xformers 具有更高的加速比。
(來源:arXiv)
在下表中,研究團隊使用 SageAttention3 和其他注意力方法比較了各種模型上的端到端質量指標。結果表明,SageAttention3 在這些模型中幾乎不會造成端到端的質量損失。
(來源:arXiv)
為了評估 SageBwd 在訓練任務中的有效性,研究團隊進行了兩個實驗。
首先,研究團隊在 GSM8K、DROP、MMLU 和 HELLASWAG 數據集上對 Qwen2.5(3B)和 Llama3.2(1B)的基礎模型進行微調。下圖顯示了微調損耗結果,表明 SageBwd 與 BF16 完全對齊。
(來源:arXiv)
此外,研究團隊對多個測試數據集上微調模型的答案質量的評估表明,SageBwd 實現了與 BF16 相同的性能。
(來源:arXiv)
其次,研究團隊使用 Llama(400M)模型在 FineWebEdu 上進行預訓練任務。下圖顯示了損耗曲線,表明雖然 SageBwd 可以實現損耗收斂,但其收斂速度相對較慢。這種限制制約了它在預訓練任務中的適用性。
(來源:arXiv)
下圖顯示了視頻生成的一些比較示例,包括使用 SageAttention3 在混元上生成視頻和在 Stable-diffsion3.5 上生成圖像。結果表明,SageAttention3 保持了完好的生成質量。
(來源:arXiv)
下圖總結了端到端推理和訓練延遲的改進情況。結果顯示,相比混元和 CogVideoX,SageAttention3 在英偉達 RTX5090 上實現了約 3 倍和 2.4 倍的端到端推理生成加速。此外,SageBwd 在英偉達 RTX4090 上使用 8K/16K token 微批量訓練 Llama(1B)時,實現了大約 1.15 倍的加速。
(來源:arXiv)
盡管 SageBwd 展現出比 FP16 實現更快的性能,但研究團隊觀察到其當前速度與理論上限之間存在顯著差距。這一差距可能是由 Triton 內核實現不夠優良導致的,研究團隊計劃進一步對其進行優化。研究團隊在論文中表示,探索低比特注意力在預訓練任務中的應用也是一個富有前景的研究方向,非常值得探索。
參考資料:
相關論文:https://.org/pdf/2505.11594
開源代碼:https://github.com/thu-ml/SageAttention
排版:劉雅坤
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.