僅需幾十行 PyTorch 代碼即可大幅提高 GPU 利用率,在英偉達(dá) A100 上的 GPU 利用率高達(dá) 70%。這一能力由一款名為 LaCT 的新模型架構(gòu)實(shí)現(xiàn),它由北京大學(xué)本科校友、美國(guó)麻省理工學(xué)院博士生張?zhí)爝h(yuǎn)和所在團(tuán)隊(duì)打造。
圖丨張?zhí)爝h(yuǎn)(來(lái)源:https://tianyuanzhang.com/)
研究團(tuán)隊(duì)通過(guò)在不同模態(tài)的任務(wù)中采用范圍從 2000 到 100 萬(wàn) tokens 的大塊更新打造了這種新模型架構(gòu)。該架構(gòu)集成了用于捕捉長(zhǎng)上下文的大塊測(cè)試時(shí)訓(xùn)練,并集成了用于針對(duì)局部結(jié)構(gòu)進(jìn)行建模的窗口注意力機(jī)制。
由于每個(gè)大塊內(nèi)的 tokens 都被視為一個(gè)無(wú)序集,因此研究團(tuán)隊(duì)將窗口注意力集成到 LaCT 中,以便捕獲塊內(nèi)的局部依賴關(guān)系,這讓 LaCT 能夠顯著提高并行性。
這能夠擴(kuò)展非線性快速權(quán)重,從而能夠提高存儲(chǔ)容量。而且,這種簡(jiǎn)單的實(shí)現(xiàn)方式能夠輕松地集成更有效的測(cè)試時(shí)優(yōu)化器(比如 Muon)。
此外,LaCT 的大塊設(shè)計(jì)也能夠很自然地用于針對(duì)各種 N 維數(shù)據(jù)進(jìn)行建模,原因在于它可以將塊大小與數(shù)據(jù)的內(nèi)部結(jié)構(gòu)對(duì)齊,例如將圖像或連續(xù)視頻幀內(nèi)的 tokens 分組為一個(gè)塊。
為了驗(yàn)證本次方法的有效性,研究團(tuán)隊(duì)在不同數(shù)據(jù)模態(tài)和任務(wù)中,包括從圖像集、語(yǔ)言模型和自回歸視頻擴(kuò)散模型中進(jìn)行了新視圖合成。
實(shí)驗(yàn)結(jié)果顯示:研究團(tuán)隊(duì)的模型能夠處理多達(dá) 128 張分辨率為 960×536 的輸入圖像,形成最多 100 萬(wàn) tokens 的序列,并且在此輸入規(guī)模下,在渲染質(zhì)量方面優(yōu)于 3D 高斯?jié)姙R(3D Gaussian Splatting)技術(shù)。
盡管語(yǔ)言數(shù)據(jù)本身并不顯式包含塊狀結(jié)構(gòu),但是與 DeltaNet 等 SOTA 方法相比,研究團(tuán)隊(duì)的模型實(shí)現(xiàn)了大致相當(dāng)?shù)男阅鼙憩F(xiàn)。
研究團(tuán)隊(duì)還通過(guò)將 LaCT 與滑動(dòng)窗口注意力相結(jié)合,將一個(gè) 140 億參數(shù)的雙向視頻擴(kuò)散 Transformer 適配為自回歸模型。這種適配后的模型可以生成包含多達(dá) 56000 個(gè)視覺(jué) tokens 的連貫視頻。
與此同時(shí),在最長(zhǎng)的序列實(shí)驗(yàn)中,他們進(jìn)行了超過(guò) 100 萬(wàn)個(gè)上下文長(zhǎng)度的新視圖合成。
目前,研究團(tuán)隊(duì)已經(jīng)開(kāi)源了代碼和權(quán)重,預(yù)計(jì) LaCT 將能推動(dòng)人們對(duì)于更高效長(zhǎng)上下文建模架構(gòu)的探索(https://tianyuanzhang.com/projects/ttt-done-right/)。
吞吐量開(kāi)銷小至 1%-3%
當(dāng)前,處理長(zhǎng)上下文的需求正在迅速增長(zhǎng)。雖然 softmax 注意力已成為建模各類數(shù)據(jù)的解決方案,但其計(jì)算成本隨序列長(zhǎng)度呈二次方增長(zhǎng),這推動(dòng)了人們對(duì)更高效長(zhǎng)上下文建模的廣泛研究。
最近,測(cè)試時(shí)間訓(xùn)練(TTT,Test-Time Training)已成為一種富有前景的高效二次序列建模方法。測(cè)試時(shí)間訓(xùn)練能將循環(huán)神經(jīng)網(wǎng)絡(luò)中的循環(huán)狀態(tài)概念擴(kuò)展到一個(gè)小型的、在線自適應(yīng)的子網(wǎng)絡(luò)。這個(gè)子網(wǎng)絡(luò)的參數(shù)也被稱為快速權(quán)重,它們通過(guò)自監(jiān)督目標(biāo)在線快速適配,以便記憶上下文中的信息。
近期,多個(gè)團(tuán)隊(duì)均探索了快速權(quán)重網(wǎng)絡(luò)的各種在線目標(biāo)、優(yōu)化器和架構(gòu)。盡管如此,已有的測(cè)試時(shí)訓(xùn)練方法仍然難以有效擴(kuò)展到長(zhǎng)上下文場(chǎng)景,根本原因在于測(cè)試時(shí)訓(xùn)練層的硬件利用率極低,在當(dāng)前 GPU 上硬件利用率通常低于峰值算力的 5%。
這種低效性是由于使用了小批量規(guī)模,即每隔一個(gè) token 或每 16 個(gè)到 64 個(gè) tokens 更新一次快速權(quán)重,之所以這樣做是因?yàn)閭鹘y(tǒng)觀點(diǎn)認(rèn)為這種方式對(duì)于上下文學(xué)習(xí)更加有效。
但是,這種小批量處理方式會(huì)導(dǎo)致并行效率低下以及計(jì)算密度不足,尤其在使用大型非線性快速權(quán)重時(shí),會(huì)給硬件高效實(shí)現(xiàn)帶來(lái)重大挑戰(zhàn),以至于實(shí)際算力利用率難以突破 10% 的有效閾值。基于此,本次研究團(tuán)隊(duì)采用相反的策略并引入了 LaCT。
如下圖所示,LaCT 塊由三種類型的層組成:窗口注意力層、大塊測(cè)試時(shí)訓(xùn)練層和前饋層。
(來(lái)源:arXiv)
每一層都配備了殘差連接,這一設(shè)計(jì)也遵循了 Transformer 架構(gòu)中的標(biāo)準(zhǔn)做法。窗口注意力層通過(guò)執(zhí)行局部自注意力,來(lái)捕捉局部依賴關(guān)系。而在測(cè)試時(shí)訓(xùn)練層,研究團(tuán)隊(duì)則將序列分割成了大塊。
研究團(tuán)隊(duì)表示,歷史上下文通過(guò)“更新”操作逐漸被壓縮到快速權(quán)重中,最新的權(quán)重被“應(yīng)用”到當(dāng)前的查詢向量(Q)上,以便計(jì)算其對(duì)應(yīng)的輸出。前饋層則執(zhí)行與 Transformer 中類似的通道混合操作。
由于測(cè)試時(shí)訓(xùn)練的“更新”操作和“應(yīng)用”操作是解耦的,因此可以自適應(yīng)地設(shè)置塊大小,并以不同的順序應(yīng)用這些操作,進(jìn)而能夠模擬不同類型的數(shù)據(jù)依賴關(guān)系。
當(dāng)分塊大小等于完整序列長(zhǎng)度時(shí),會(huì)先執(zhí)行“應(yīng)用”操作再執(zhí)行“更新”操作,這在概念上與全注意力機(jī)制相似。通過(guò)交替使用“更新”操作和“應(yīng)用”操作,能夠形成分塊因果掩碼,其中分塊大小與塊大小互相對(duì)應(yīng)。在兩個(gè)操作之間切換順序會(huì)導(dǎo)致掩碼發(fā)生偏移,偏移掩碼不會(huì)在塊內(nèi)泄露未來(lái)信息,這在語(yǔ)言建模中構(gòu)建完整因果掩碼時(shí)非常重要。
(來(lái)源:arXiv)
大塊測(cè)試時(shí)訓(xùn)練層會(huì)將數(shù)據(jù)視為集合序列,因?yàn)槠淇焖俚臋?quán)重更新會(huì)忽略每個(gè)塊內(nèi)的 tokens 順序和空間局部性。然而,許多數(shù)據(jù)模態(tài)比如視頻、圖像集合或文本,并不完全符合這種基于集合的視角。對(duì)于這些模態(tài)而言,塊內(nèi)結(jié)構(gòu)和局部性對(duì)于捕獲整體數(shù)據(jù)結(jié)構(gòu)至關(guān)重要。
因此,研究團(tuán)隊(duì)將局部窗口注意力層與測(cè)試時(shí)訓(xùn)練層集成在一起,以便處理塊內(nèi)的數(shù)據(jù)結(jié)構(gòu)。此外,窗口注意力機(jī)制能有效捕捉數(shù)據(jù)中的局部特征。對(duì)于測(cè)試時(shí)訓(xùn)練層來(lái)說(shuō),這讓它能夠?qū)⑵涔潭ù笮〉目焖贆?quán)重容量集中用于建模非局部依賴關(guān)系。
總的來(lái)說(shuō),LaCT 是一種混合架構(gòu),它采用二次計(jì)算注意力機(jī)制來(lái)處理局部結(jié)構(gòu),針對(duì)非局部上下文采用線性計(jì)算的測(cè)試時(shí)訓(xùn)練機(jī)制。上下文并行(CP,Context Parallelism)沿著上下文長(zhǎng)度維度針對(duì)序列進(jìn)行分區(qū),并將分片分布在多個(gè)設(shè)備上來(lái)進(jìn)行并行計(jì)算。
前饋層和窗口注意力均屬于局部操作算子,因此天然地支持上下文并行。對(duì)于測(cè)試時(shí)訓(xùn)練層,小塊難以支持上下文并行,因此更傾向于使用張量并行。
研究團(tuán)隊(duì)的大塊測(cè)試時(shí)訓(xùn)練層通過(guò)在塊內(nèi)分片 tokens 來(lái)實(shí)現(xiàn)上下文并行。在訓(xùn)練新視圖合成時(shí),他們采用了這種并行方法,并觀察到 1% 至 3% 的極小吞吐量開(kāi)銷。與此同時(shí),LaCT 架構(gòu)可以與數(shù)據(jù)并行、流水線并行和張量并行等其他并行策略兼容。
實(shí)驗(yàn)涵蓋:新視圖合成、語(yǔ)言建模和自回歸視頻生成
如前所述,研究團(tuán)隊(duì)開(kāi)展了關(guān)于新視圖合成、語(yǔ)言建模和自回歸視頻生成的實(shí)驗(yàn)。在與線性成本基線方法的對(duì)比實(shí)驗(yàn)中,研究團(tuán)隊(duì)為其增加了相同的窗口注意力模塊,以便確保能夠進(jìn)行公平的比較。
表丨對(duì)每個(gè)實(shí)驗(yàn)中關(guān)鍵因素的總結(jié)(來(lái)源:arXiv)
在新視圖合成上,研究團(tuán)隊(duì)在場(chǎng)景級(jí)和物體級(jí)數(shù)據(jù)集上對(duì)本次方法進(jìn)行評(píng)估。他們使用 Objaverse 數(shù)據(jù)集進(jìn)行物體級(jí)訓(xùn)練,并遵循 LVSM 和 GS - LRM 的設(shè)置。
訓(xùn)練完成之后,研究團(tuán)隊(duì)在 Google Scanned Objects(GSO)數(shù)據(jù)集上進(jìn)行評(píng)估,該數(shù)據(jù)集的分辨率分別為 256×256 和 512×512。每次評(píng)估涉及 4 到 48 個(gè)輸入視圖,且每個(gè)物體有 8 個(gè)新視圖。
對(duì)于場(chǎng)景級(jí)評(píng)估,研究團(tuán)隊(duì)采用挑戰(zhàn)性較高的 DL3DV 場(chǎng)景數(shù)據(jù)集,其中包含超過(guò) 11000 個(gè)訓(xùn)練場(chǎng)景和 140 個(gè)測(cè)試場(chǎng)景,每個(gè)場(chǎng)景大約有 300 個(gè)視圖,評(píng)估的分辨率為 960 × 536。
對(duì)于物體級(jí)評(píng)估,研究團(tuán)隊(duì)使用了如下兩個(gè)基線模型:全注意力模型和寄存器注意力模型。
全注意力基線模型將測(cè)試時(shí)訓(xùn)練層替換為逐塊因果注意力層,實(shí)現(xiàn)了輸入 tokens 之間的雙向交互和來(lái)自新視圖的交叉注意力。
寄存器注意力模型將輸入 tokens 壓縮到 4096 個(gè)寄存器中,并通過(guò)與這些寄存器的交叉注意力解碼新視圖。
在場(chǎng)景級(jí)評(píng)估中,研究團(tuán)隊(duì)與 LongLRM 進(jìn)行對(duì)比,LongLRM 是一種結(jié)合了 Mamba 和全注意力機(jī)制的模型,可用于 3D 高斯濺射(3D Gaussian splat)預(yù)測(cè)。此外,他們還與純基于優(yōu)化的 3D 高斯濺射方法進(jìn)行了對(duì)比。
表丨對(duì)所有模型計(jì)算復(fù)雜性的總結(jié)(來(lái)源:arXiv)
在性能評(píng)估上,研究團(tuán)隊(duì)采用每 tokens 損失度量來(lái)評(píng)估模型有效使用完整上下文的能力。出現(xiàn)單調(diào)遞減的損失表示上下文利用成功,而處于平穩(wěn)狀態(tài)則表示上下文使用有限。
另?yè)?jù)悉,他們從原始 LaCT 塊中移除了窗口注意力層,將滑動(dòng)窗口注意力(SWA,sliding window-attention)層直接集成到大塊測(cè)試時(shí)訓(xùn)練層中,并將模型與全注意力模型、門控線性注意力(GLA,Gated Linear Attention)和 DeltaNet 進(jìn)行了比較。
為了確保公平性,他們?yōu)?GLA 和 DeltaNet 都增強(qiáng)了相同的滑動(dòng)窗口注意力層,并采用 100 萬(wàn)的 RoPE 庫(kù)進(jìn)行 32K tokens 上下文的培訓(xùn)。
表丨對(duì)所有方法機(jī)制和訓(xùn)練吞吐量的總結(jié)(來(lái)源:arXiv)
為了比較塊遞歸和逐 tokens 遞歸,在條件受控的實(shí)驗(yàn)中,研究團(tuán)隊(duì)的線性大塊遞歸策略在相同狀態(tài)大小下優(yōu)于線性逐 tokens 遞歸策略在。
由于語(yǔ)言本身并不天然存在塊狀結(jié)構(gòu),研究團(tuán)隊(duì)提出的線性大塊遞歸變體在初始階段性能不如 GLA 和 DeltaNet 等逐 token 方法。然而,當(dāng)將其與大規(guī)模非線性狀態(tài)以及 Muon 優(yōu)化器相結(jié)合時(shí),該變體的表現(xiàn)將超越這些逐 token 方法。
總的來(lái)說(shuō),本次成果凸顯了大塊測(cè)試時(shí)訓(xùn)練在計(jì)算效率和性能上的優(yōu)勢(shì),為更高效且可擴(kuò)展的長(zhǎng)上下文序列建模鋪平了道路。
通過(guò)消除對(duì)于低級(jí)硬件特定實(shí)現(xiàn)的依賴,LaCT 使人們能夠更廣泛地探索架構(gòu)設(shè)計(jì)空間。未來(lái),研究團(tuán)隊(duì)希望這項(xiàng)工作能夠啟發(fā)并加速長(zhǎng)上下文建模和測(cè)試時(shí)訓(xùn)練領(lǐng)域的新研究。
參考資料:
https://arxiv.org/abs/2505.23884
運(yùn)營(yíng)/排版:何晨龍
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.