隨著大型模型需要處理的序列長(zhǎng)度不斷增加,注意力運(yùn)算(Attention)的時(shí)間開(kāi)銷(xiāo)逐漸成為主要開(kāi)銷(xiāo)。此前,清華大學(xué)陳鍵飛團(tuán)隊(duì)提出的即插即用的 SageAttention 和 SageAttention2 已經(jīng)被業(yè)界及社區(qū)廣泛的使用于各種開(kāi)源及商業(yè)的大模型中,比如 Vidu,CogvideoX,Mochi,Wan,HunyuanVideo,F(xiàn)lux,Llama3,Qwen 等。
近日,清華大學(xué)陳鍵飛團(tuán)隊(duì)進(jìn)一步提出了針對(duì) BlackWell 架構(gòu)的首個(gè)全 FP4 量化的即插即用注意力算子(SageAttention3)。實(shí)現(xiàn)了5倍相比于 FlashAttention 的即插即用的推理加速(此前的 SageAttention V1/V2/V2++ 分別達(dá)到了 2.1,3,3.9 倍的加速效果),比如在 RTX 5090 上,SageAttention3 達(dá)到了1040 TOPS的速度,甚至是比 RTX 5090 昂貴十幾倍的 H100 上使用 Hopper 獨(dú)有的 FlashAttention3 還要快 1.65 倍!SageAttention3 在多種視頻和圖像生成等大模型上(包括 HunyuanVideo,CogVideoX,Mochi和各類(lèi)圖像生成模型)均保持了端到端的精度表現(xiàn)。同時(shí)還首次提出可訓(xùn)練的 8 比特注意力(SageBwd)用于大模型的訓(xùn)練加速(注:FlashAttention3 的 FP8 版本也只支持前向傳播),在各項(xiàng)微調(diào)任務(wù)中均保持了與全精度注意力相同的結(jié)果。
- 論文標(biāo)題:SageAttention3: Microscaling FP4 Attention for Inference and An Exploration of8-bitTraining
- 論文鏈接:https://arxiv.org/abs/2505.11594
- 開(kāi)源代碼:https://github.com/thu-ml/SageAttention
效果預(yù)覽
SageAttention3 實(shí)現(xiàn)了高效的 Attention 算子,可以實(shí)現(xiàn)即插即用的 5 倍于 FlashAttention 的推理加速。即輸入任意 Q, K, V 矩陣,可以快速返回 Attention Output (O),真正做到了兩行代碼加速任意模型推理。(注:按照官方倉(cāng)庫(kù)中的開(kāi)源計(jì)劃,SageAttention2++ 的代碼將于6月20日左右開(kāi)源,SageAttention3 的代碼將于7月15日左右開(kāi)源。)
效果上,以 HunyuanVideo 為例,在 RTX5090 上 SageAttention3 可以 3 倍加速端到端的生成,且視頻質(zhì)量幾乎無(wú)損:
https://mp.weixin.qq.com/s/aVlYM_cMrpTKeH3ao-UJuA
視頻 1(使用 FlashAttention2,490s)
https://mp.weixin.qq.com/s/aVlYM_cMrpTKeH3ao-UJuA
視頻 2(使用 SageAttention3,164s)
(注:FlashAttention2 已經(jīng)是在 RTX5090 上最優(yōu)的 FlashAttention 實(shí)現(xiàn)。)
接下來(lái),將從前言,挑戰(zhàn),方法,以及實(shí)驗(yàn)效果四個(gè)方面介紹 SageAttention3。
SageAttention3 總體流程圖
前言
隨著大模型需要處理的序列長(zhǎng)度越來(lái)越長(zhǎng),Attention 的速度優(yōu)化變得越來(lái)越重要。下圖展示了一個(gè)標(biāo)準(zhǔn)的 Transformer 模型中各運(yùn)算的時(shí)間占比隨序列長(zhǎng)度的變化:
為了方便指代注意力運(yùn)算中的矩陣,我們先回顧一下注意力的計(jì)算公式:
研究動(dòng)機(jī):(1)Blackwell 架構(gòu)有著速度極快的 FP4 Tensor Core,以 RTX5090 為例,其速度是 FP16 Tensor Core 的 8 倍。(2)訓(xùn)練階段的注意力運(yùn)算開(kāi)銷(xiāo)也同樣重要,在此之前并沒(méi)有工作嘗試過(guò)低比特注意力加速模型訓(xùn)練,包括 FlashAttention3 的 FP8 版本也只有 Forward 過(guò)程。我們還希望同時(shí)量化注意力的前向 + 反向過(guò)程來(lái)加速訓(xùn)練。
FP4 注意力量化有什么問(wèn)題?
(1)FP4 數(shù)值類(lèi)型僅有 15 個(gè)有效數(shù)值,這使得以 Tensor(Per-tensor)或以 Token(Per-token)粒度的量化都難以有效保證量化的準(zhǔn)確度。
(2)P 矩陣的值分布在 [0, 1] 之間,直接的 FP4 量化會(huì)使量化縮放因子被限制在一個(gè)狹窄的范圍內(nèi)。然而,硬件要求這些量化因子必須采用 FP8 數(shù)據(jù)類(lèi)型表示。此時(shí),將縮放因子轉(zhuǎn)為 FP8 時(shí)會(huì)導(dǎo)致顯著的精度損失。
8-Bit 注意力用于訓(xùn)練有什么問(wèn)題?
(1)P 矩陣的梯度對(duì)量化誤差過(guò)于敏感,并且在反向過(guò)程中還會(huì)沿著序列長(zhǎng)度對(duì) Q 和 K 的梯度造成誤差累積。
技術(shù)方案
為了解決上述的挑戰(zhàn),研究團(tuán)隊(duì)提出了對(duì)應(yīng)的解決辦法。
(1)為了提高 FP4 的量化精度。研究團(tuán)隊(duì)采用了 Microscaling FP4 量化,這是 BlackWell 硬件層面支持的一種量化方式。即可以采用 或 的量化粒度進(jìn)行矩陣量化,NIVIDA 在硬件層面自動(dòng)支持了反量化過(guò)程。此外,Microscaling FP4 有兩種數(shù)據(jù)表示的形式,一種是MXFP4, 另外一種是 NVFP4。兩種格式都采用了 E2M1 的 FP4 數(shù)據(jù)類(lèi)型。不同的是,NVFP4 的量化的塊大小為,縮放因子的數(shù)據(jù)類(lèi)型為 E4M3。MXFP4 的量化的塊大小為,縮放因子的數(shù)據(jù)格式為 E8M0。研究團(tuán)隊(duì)采用了 NVFP4 數(shù)據(jù)格式,因?yàn)槠淞炕瘻?zhǔn)確率遠(yuǎn)高于 MXFP4:
(2)針對(duì) P 的縮放因子范圍狹窄的問(wèn)題,研究團(tuán)隊(duì)提出了兩階段量化(Two-level Quantization)的辦法。FlashAttention 中的 P 矩陣的值在 [0, 1] 的范圍內(nèi),導(dǎo)致 P 的縮放因子的范圍也只在 0~0.167 之間。把縮放因子直接轉(zhuǎn)換為 FP8 格式會(huì)帶來(lái)極大的精度損失。
于是研究團(tuán)隊(duì)決定先把 P 通過(guò) Per-token 量化到 [0, ] 的范圍內(nèi),再進(jìn)行 FP4 的量化:
下表展示了 Two-Level Scaling 對(duì)精度的提升:
下圖展示了 SageAttention3 的算法流程:
(3)在 8-Bit 訓(xùn)練 Attention 當(dāng)中,研究團(tuán)隊(duì)對(duì) Q,K,V 采用了 Per-block INT8 量化,對(duì) P 巧妙地采用了無(wú)量化 Overhead 的 Per-token 量化。前向過(guò)程的算法如下:
在反向傳播的過(guò)程中總共涉及到 5 個(gè)矩陣乘法:
研究團(tuán)隊(duì)發(fā)現(xiàn)是否量化 dOVT 對(duì)精度有著較大的影響:
于是研究團(tuán)隊(duì)將 dOVT 保留為 FP16 精度,而對(duì)其它四個(gè)矩陣乘法進(jìn)行了量化。以下是反向傳播的算法:
實(shí)驗(yàn)效果
SageAttention3 實(shí)現(xiàn)了 GPU 底層的 CUDA Kernel,在算子速度以及各個(gè)模型端到端準(zhǔn)確度上都有十分不錯(cuò)的表現(xiàn)。
具體來(lái)說(shuō),算子速度相比于 FlashAttention2(5090 上最快的 FlashAttention) 和 xformers 有大約 5 倍以及 10 倍的加速:
各模型在真實(shí)場(chǎng)景的端到端精度表現(xiàn)中,在視頻、圖像生成等大模型上均保持了端到端的精度表現(xiàn):
下圖是在 HunyuanVideo 當(dāng)中的可視化實(shí)例:
下圖是在 Flux 上的可視化實(shí)例:
下圖是在 Cogvideo 中的可視化實(shí)例:
下表展示了各個(gè)視頻、圖像生成模型中 SageAttention3 的端到端精度表現(xiàn):
端到端的速度表現(xiàn)上,SageAttention3 的實(shí)現(xiàn)均可以有效地對(duì)長(zhǎng)序列的模型進(jìn)行加速,比如可以端到端 3 倍加速 HunyuanVideo:
8-Bit 訓(xùn)練 Attention 在 Base Model 微調(diào)到 Instruct Model 的任務(wù)上展現(xiàn)出與 BF16 的注意力完全一致的精度表現(xiàn),下表是在多個(gè)不同的任務(wù)以及模型上微調(diào)的結(jié)果:
并且在訓(xùn)練速度上也能起到較好的加速效果:
研究團(tuán)隊(duì)還發(fā)現(xiàn),目前的 8 比特用于訓(xùn)練的 Attention 雖然在微調(diào)任務(wù)上完全無(wú)損,但是在預(yù)訓(xùn)練任務(wù)上與全精度的 Attention 在 Loss 上還有一定差距,需要未來(lái)進(jìn)一步的研究:
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶(hù)上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.