該工作第一作者為螞蟻技術(shù)研究院副研究員胡翔,螞蟻技術(shù)研究院高級研究員武威為通訊作者。
在大語言模型如火如荼的當下,長文本建模仍然是一個極具挑戰(zhàn)的問題。糾其根源,一方面在于主流 LLMs 的架構(gòu) Transformers 中平方復雜度及隨序列長度線性增長的推理階段顯存開銷;另一方面在于 full-attention 有限的外推能力,難以泛化到遠超預訓練階段長度的輸入。
而高效處理長上下文能力,除了簡單的工業(yè)界降本增效的需求外,還涉及通用人工智能 (AGI) 的核心問題:具有永久記憶的智能體。如果將人類從出生開始接收到的信息視作長上下文,人類擁有記憶無非是訪問這些上下文。因此記憶可以看作是超長上下文訪問能力,而擁有與用戶所有對話記憶的智能體,很可能為大語言模型公司構(gòu)建數(shù)據(jù)護城河 (事實上,OpenAI 已經(jīng)開放了類似能力)。
近日,螞蟻的研究團隊為這個問題帶來了一個新思路。就像人類開卷考試只會挑和當前問題相關(guān)的關(guān)鍵頁作為參考,語言模型也可以只關(guān)注與當前上下文相關(guān)的過去片段。以此為出發(fā)點,他們提出一種基于因果檢索的注意力機制 GCA (Grouped Cross Attention),完全端到端地學習如何從上文檢索并挑選最相關(guān)片段,從而實現(xiàn)超長序列高性能處理與泛化能力。人類記憶的另一個特性是大部分時候記憶處于沉睡狀態(tài),相關(guān)記憶片段只會在激活時進入意識。類似地,GCA 通過將上文信息卸載到 CPU / 磁盤,只在需要的時候動態(tài)加載需要的片段到 GPU 的方式,大幅降低了長文本處理的顯存開銷。
目前,GCA 的 Triton kernel 實現(xiàn)已全部開源,相關(guān)論文已被 ICML 2025 接收。
- 論文標題:Efficient Length-Generalizable Attention via Causal Retrieval for Long-Context Language Modeling
- 論文地址:https://arxiv.org/abs/2410.01651
- GitHub 主頁:https://github.com/ant-research/long-context-modeling
實驗結(jié)果也令人振奮:整合 GCA 的模型不僅在長文本數(shù)據(jù)集上展現(xiàn)了更優(yōu)的 perplexity,更展現(xiàn)了 1000 倍以上的長度泛化能力,在 16K 上下文預訓練的模型可在 16M 長上下文密鑰檢索 (passkey retrieval) 實現(xiàn) 100% 準確率,并在更復雜的多跳檢索任務持續(xù)展現(xiàn)了超強外推能力。此外長度泛化與檢索能力效果拔群,基于 GCA 的模型訓練開銷隨序列長度幾乎呈線性關(guān)系,并且推理的顯存開銷接近常數(shù),同時基本持平 Transformers 推理速度。
值得一提的是,本工作 24 年 10 月在 arXiv 公開后,國產(chǎn)之光 DeepSeek 在 25 年初公開了 NSA,兩者思路都是通過挑選過去 chunk 并 attention 的方式實現(xiàn)性能優(yōu)化。但各有側(cè)重,GCA 核心亮點在于超長的長度泛化,NSA 通過巧妙的 kernel 設(shè)計實現(xiàn)了逐 token 的稀疏 attention。受 NSA 的啟發(fā),GCA 的后繼工作 HSA (https://arxiv.org/abs/2504.16795) 結(jié)合了兩者的優(yōu)點進行了融合。
長文本處理難點及現(xiàn)有方案的局限性
近年來,有不少工作討論 Transformers (TRMs) 架構(gòu)如何高效處理長文本。因為基于全量上文 attention 的 TRMs 有一個很顯著的局限:輸入長度超過預訓練長度一定程度后,perplexity 會飆升,無法生成正常文本。如果只是解決正常生成的問題,一個最簡單的思路是滑動窗口注意力,即每個 token 僅關(guān)注最鄰近的 N 個 token 即可。這種方式可以保證 LLMs 持續(xù)生成,但它犧牲了長程信息獲取能力。
另一種思路是認為 attention 窗口擴大到預訓練長度范圍之外后會導致原本的 attention 權(quán)重分布發(fā)生變化,因此通過調(diào)整 softmax 溫度的方式進行長度泛化。但這類方法經(jīng)實驗驗證往往泛化的倍率也有限。
因此,attention 長度泛化的難點在于處理超長序列的同時,能夠真正有效利用上文中的信息。
GCA: 基于端到端因果檢索的注意力機制
現(xiàn)有一些工作通過檢索增強 (RAG) 的思路來進行長文本建模,其基本思路是將文本分段,譬如每 64 個 token 為一個 chunk;每生成一個 chunk 后,模型根據(jù)當前上文信息檢索歷史 chunk 來輔助下一個 chunk 的生成。理想情況下,只要能檢索到對下文生成最有幫助的 chunk,再通過 cross-attention 機制從相關(guān) chunk 收集信息即可。但通常檢索模塊是單獨訓練的,只能檢索到相似內(nèi)容,無法保證挑選對下文生成最有幫助的 chunk。
和已有工作相比,GCA 的一個顯著優(yōu)勢是能夠與自回歸語言模型聯(lián)合預訓練,從而實現(xiàn)端到端學習。
上圖對比了 GCA 與傳統(tǒng)檢索方式的運作區(qū)別。傳統(tǒng)方式中 (a), 檢索模塊檢索并返回相關(guān) chunk,但檢索分只用于挑選 chunk 完全不參與 forward 運算,因此無法獲得梯度,無法學習。GCA 的核心創(chuàng)新在于通過一種兩階段的注意力機制,使得每個 chunk 的檢索分能參與到自回歸預測中,如圖中(b)所示。
1. 分組注意力機制
不同于 (a) 中直接將 chunk 拼接在一起進行 attention, GCA 分別對每個 chunk 進行 attention (分組 attention),從各個 chunk 收集 token 粒度的信息并整合,作為每個 chunk 整體的信息。
2. Chunk-level 信息融合
GCA 將每個 chunk 的檢索相關(guān)分通過 softmax 得到一個概率分布,將其作為權(quán)重對第一步所有 chunk 的表征進行加權(quán)求和,融合所有 chunk 信息用于下一個 token 預測。在反向傳播過程中,更有助于預測下文的 chunk 將被分配更大的權(quán)重,從而實現(xiàn)檢索模塊的端到端學習。
模型整體架構(gòu)是通過 GCA 與 sliding window attention 結(jié)合實現(xiàn)長上下文建模;前者負責長程信息檢索,后者負責整合短程信息。為了進一步提升 GCA 性能,降低顯存開銷,研究團隊將整個 GCA 封裝成由 Triton 實現(xiàn)的 kernel,方便未來工作可以直接復用。
實驗結(jié)果
在語言模型,長程檢索等任務上的實驗表明:
1. 基于 GCA 的 128M 的模型在大海撈針任務即可超越大部分主流 7B 模型,達成 1000 倍外推,實現(xiàn) 16M 上下文的完美大海撈針
在該實驗中,所有模型都僅在不超過 16K 的上下文進行預訓練,baseline 囊括了包含 sliding window attention 等主流注意力機制。基于 GCA 的模型無論在簡單大海撈針,還是更復雜的變量追蹤任務,都保持了穩(wěn)定的外推能力。
注意到幾乎所有 baseline 在上下文長度超過 64K 后幾乎都歸零,這些不同模型存在不同原因。劃窗注意力因為只能看最鄰近的 token,無法實現(xiàn)長程信息獲取;基于循環(huán)結(jié)構(gòu)的由于所有上下文信息都被壓縮在一個固定維度的表征,必然存在信息損失的問題;基于單獨訓練檢索器的模型 (RPTContriever) 的結(jié)果進一步驗證了檢索模型未必能檢索到對下文有幫助的上文。
這一結(jié)果經(jīng)驗性地為可長度泛化的注意力機制提供了一個成功的概念原型。同時證明可泛化的長程信息獲取能力取決于注意力機制原理上的改進,與參數(shù)量的提升無關(guān)。
在摘要及 RULER 榜單的效果
2. 預訓練高效,推理時顯存開銷接近常數(shù):GCA 是一種 sparse attention,其 attention 的視野域保持常數(shù),因此在 batch size 一定的情況下,訓練開銷幾乎與序列長度呈線性。由于 GCA 在生成階段將所有上文的 KV cache 都卸載到 CPU,每次檢索的時候才把相關(guān) chunk 的 kv cache 載入 GPU,因此超長上文也不會有 KV cache 顯存爆炸的問題。而 GPU-CPU 的交換控制在每 64 個 token 一次,因此對推理速度影響非常小,從而實現(xiàn)接近常數(shù)的顯存開銷,但仍保持高效的推理速度及長程信息獲取能力。
訓練時間及 ppl 隨序列長度的變化
推理速度與顯存開銷相比基線 (基于劃窗注意力的 Transformers) 的倍率關(guān)系(越低越好)
相同條件不同模型各個參數(shù)規(guī)模下的訓練吞吐量,相比劃窗注意力有額外 20% 的開銷,但帶來超長程信息獲取的能力
3. 在 arXiv-math 上的數(shù)據(jù)分析發(fā)現(xiàn),通過 GCA,語言模型會根據(jù)當前上下文,檢索下文生成中可能會用到的引理及變量聲明。這說明 GCA 學到的不僅僅是字面相似性,更包含了語義乃至邏輯相關(guān)性。
黑體是當前 chunk,紅色,藍色,黃色,分別代表 top3 相關(guān) chunk、
結(jié)語
本工作提出一種可以長度泛化的稀疏注意力機制 GCA, 其核心在于可導的檢索模塊,可以有效處理 1000 倍于預訓練長度的文本,首次實現(xiàn)在 16M 長度完美的大海撈針。雖然當前實驗的模型規(guī)模較小,但期望該工作可以為機器如何實現(xiàn)永久記憶提供新的研究思路。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.