機(jī)器之心報(bào)道
編輯:欣東、陳陳
本文介紹了一種名為「嫁接」的技術(shù),用于在小計(jì)算預(yù)算下通過編輯預(yù)訓(xùn)練 Diffusion Transformers(簡(jiǎn)稱 DiTs)來探索新的模型架構(gòu)設(shè)計(jì)。這種方法允許研究者在不從頭開始訓(xùn)練模型的情況下,通過替換模型中的某些算子(如 MLP)來創(chuàng)建新的混合架構(gòu),從而在保持模型質(zhì)量的同時(shí)減少計(jì)算量。
模型架構(gòu)設(shè)計(jì)在機(jī)器學(xué)習(xí)中扮演著核心角色,與數(shù)據(jù)、算法、算力和基準(zhǔn)測(cè)試一樣重要。它定義了模型函數(shù)、算子選擇(如注意力機(jī)制、卷積)和配置設(shè)定(如模型深度、寬度)等等模型要素。
盡管如此,由于從頭訓(xùn)練模型的成本過高 —— 尤其人們難以獲得關(guān)于架構(gòu)設(shè)計(jì)的深刻洞見(即哪些方案有效、哪些無效)。因此,研究新架構(gòu)仍是一項(xiàng)挑戰(zhàn),對(duì)生成模型而言尤為如此。
在本文中,來自斯坦福大學(xué)、 Liquid AI 等機(jī)構(gòu)的研究者探索了這一問題,即對(duì)預(yù)訓(xùn)練模型進(jìn)行架構(gòu)編輯來研究新架構(gòu)。
- 論文鏈接:https://arxiv.org/pdf/2506.05340v1
- 論文主頁:https://grafting.stanford.edu/
- 論文標(biāo)題: Exploring Diffusion Transformer Designs via Grafting
具體而言,該研究提出了一種編輯預(yù)訓(xùn)練擴(kuò)散 transformer(DiT)的簡(jiǎn)單方法,即 Grafting(嫁接),該方法可以在較小的計(jì)算預(yù)算下實(shí)現(xiàn)新的架構(gòu)。
嫁接過程如下:
(i)激活蒸餾:此階段通過回歸目標(biāo)(regression objective)蒸餾原始算子的激活特征,將其功能遷移至新算子。該階段核心在于實(shí)現(xiàn)算子間的功能傳遞。
(ii)輕量級(jí)調(diào)優(yōu):此階段通過使用有限的數(shù)據(jù)進(jìn)行調(diào)優(yōu),減輕了由于集成多個(gè)新算子而導(dǎo)致的誤差傳播。
此外,架構(gòu)編輯還涵蓋多種策略,如添加、刪除和替換算子。
本文還基于 DiT-XL/2 構(gòu)建了一個(gè)測(cè)試平臺(tái),以研究嫁接對(duì)模型質(zhì)量的影響。
利用該測(cè)試平臺(tái),本文通過嫁接技術(shù)開發(fā)了一系列混合設(shè)計(jì):用門控卷積、局部注意力和線性注意力取代 Softmax 注意力,用可變擴(kuò)展率和卷積變體取代 MLP。
值得注意的是,許多混合設(shè)計(jì)使用不到 2% 的預(yù)訓(xùn)練計(jì)算資源就實(shí)現(xiàn)了良好的質(zhì)量(FID:2.38–2.64,而 DiT-XL/2 為 2.27)。然后,本文嫁接了一個(gè)文本轉(zhuǎn)圖像模型 (PixArt-Σ),實(shí)現(xiàn)了 1.43 倍的加速,而 GenEval 分?jǐn)?shù)下降不到 2%。
最后,本文展示了一個(gè)案例研究,該研究通過嫁接技術(shù)將每對(duì)序列 Transformer 模塊轉(zhuǎn)換為并行模塊,從而重構(gòu)了 DiT-XL/2。這將模型深度減少到原來一半,并獲得了比其他同等深度模型更高的質(zhì)量(FID:2.77)。
總而言之,該研究展示了可以通過預(yù)訓(xùn)練 DiT 來探索新的擴(kuò)散模型設(shè)計(jì),其修改范圍涵蓋從算子替換到架構(gòu)重構(gòu)。
嫁接擴(kuò)散 Transformer
兩階段嫁接方法
嫁接旨在通過編輯預(yù)訓(xùn)練模型的計(jì)算圖來實(shí)現(xiàn)新架構(gòu)。由于該研究專注于用替代方案替換現(xiàn)有算子,這引出了兩個(gè)問題:
問題 1:在將新算子集成到計(jì)算圖之前,應(yīng)該如何初始化?
對(duì)應(yīng)第一階段:通過激活蒸餾進(jìn)行初始化。由于 DiT 的激活是連續(xù)且平滑的,這可以被視為一個(gè)回歸問題:
問題 2:當(dāng)多個(gè)算子集成到計(jì)算圖時(shí),如何減輕誤差傳播?
對(duì)應(yīng)第二階段:輕量級(jí)調(diào)優(yōu)。隨著更多算子被替換,初始化誤差會(huì)不斷傳播,導(dǎo)致與預(yù)訓(xùn)練模型的行為出現(xiàn)偏差。
本文采用端到端微調(diào)來緩解階段 1 的累積誤差。微調(diào)目標(biāo)函數(shù)如公式 1 所示。
實(shí)踐中,本文發(fā)現(xiàn),即使替換 DiT-XL/2 中的所有 MHA 或 MLP 層,僅使用 10% 的訓(xùn)練數(shù)據(jù)也能恢復(fù)競(jìng)爭(zhēng)性能。
自嫁接基準(zhǔn)
在研究新的架構(gòu)設(shè)計(jì)之前,該研究引入了自嫁接(self-grafting),這是一種簡(jiǎn)單的對(duì)照設(shè)置:將現(xiàn)有算子(如 MHA、MLP)替換為相同類型但權(quán)重隨機(jī)初始化的算子。這樣可以保持計(jì)算圖的結(jié)構(gòu) —— 包括算子類型和參數(shù)數(shù)量 —— 但改變了具體的計(jì)算過程。自嫁接有三方面作用:(1)評(píng)估在不改變架構(gòu)的情況下嫁接流程本身的效果;(2)為比較不同的替換方案提供一個(gè)性能基準(zhǔn);(3)研究影響性能的因素,如數(shù)據(jù)規(guī)模、回歸目標(biāo)和超參數(shù)。
激活行為分析以及自嫁接結(jié)果
本文首先分析了 DiT-XL/2 層中的 MHA 和 MLP 算子激活行為。在這兩種情況下,本文觀察到激活值存在較大差異,尤其是在較深的層中(表 1 (i, ii))。
經(jīng)過分析,本文得出通過選擇特定于算子的回歸目標(biāo),可以實(shí)現(xiàn)高質(zhì)量的初始化。
如表 1 (iii,iv) 所示,回歸目標(biāo)的選擇會(huì)影響性能。對(duì)于 MHA,L1 實(shí)現(xiàn)了最佳 FID(2.51),其次是 Huber(2.55)和 L2(2.58)。對(duì)于 MLP,L2 表現(xiàn)最佳(2.33),而 L1 表現(xiàn)不佳(2.83);值得注意的是,MLP 的參數(shù)量是 MHA 的 2 倍。
這表明高質(zhì)量的初始化需要量身定制的、激活感知的策略。
研究還發(fā)現(xiàn),使用 10% 的數(shù)據(jù)進(jìn)行完全自嫁接可實(shí)現(xiàn)接近基線的性能。表明在適度的數(shù)據(jù)和計(jì)算預(yù)算下完全自嫁接是可行的。
實(shí)驗(yàn)
實(shí)驗(yàn) I:通過嫁接實(shí)現(xiàn)混合架構(gòu)
本節(jié)實(shí)驗(yàn)圍繞這個(gè)問題進(jìn)行:當(dāng)現(xiàn)有算子被高效的替代方案取代時(shí),我們能否保持模型質(zhì)量?
為了探究這個(gè)問題,本文研究了以下嫁接過程:
1. 待替換算子的類型 ——MHA 或 MLP;
2. 替換算子的類型 —— 例如卷積;
3. 層選擇策略 —— 替換所有層中的算子或使用啟發(fā)式選擇;
4. 替換率 —— 全部替換或部分替換。
為了實(shí)驗(yàn),該研究構(gòu)建了一個(gè)測(cè)試平臺(tái),并提出兩種層選擇策略:完全替換和交錯(cuò)替換。測(cè)試平臺(tái)詳見表 3。
此外,該研究還引入了 Hyena-X 和 Hyena-Y 兩種新的高效門控卷積算子,并設(shè)計(jì)為 MHA 的直接替代品。Figure 3 展示了它們的結(jié)構(gòu)。
MHA 結(jié)果。通過嫁接替換 DiT-XL/2 中的 MHA 算子,獲得了良好的質(zhì)量 - 效率權(quán)衡。主要發(fā)現(xiàn)如下:
在交錯(cuò)嫁接下,較小的感受野表現(xiàn)出驚人的效果。實(shí)驗(yàn)發(fā)現(xiàn),在 50% 交錯(cuò)替換比例下,滑動(dòng)窗口注意力(SWA)、Hyena-X/Y 和 Mamba-2 等替代方案均能保持 FID 分?jǐn)?shù)與基線(2.27)差距在 0.5 以內(nèi)。尤其值得注意的是,盡管 SWA 和 Hyena 變體的感受野有限(卷積核 K=4 / 窗口 w=4),其 FID 下降幅度卻極小。
替換策略:交錯(cuò)替換 vs. 完全替換。將交錯(cuò)替換比例從 50% 提升至 75% 時(shí),性能通常下降,但 SWA 在 75% 交錯(cuò)替換下仍有效(FID=3.09)。100% 替換時(shí),性能急劇惡化(所有 FID > 75),這與局部性分析一致,表明只有部分層是局部且適合嫁接的。
數(shù)據(jù)規(guī)模和層選擇的消融實(shí)驗(yàn)結(jié)果。
MLP 結(jié)果顯示通過嫁接的方式替換 MLP 算子是有效的。
經(jīng)過實(shí)驗(yàn),得出要點(diǎn) 1:嫁接對(duì)于在較小的計(jì)算預(yù)算下構(gòu)建具有良好生成質(zhì)量的高效混合架構(gòu)非常有效。交錯(cuò)設(shè)計(jì)尤其有效。
實(shí)驗(yàn) II:通過嫁接改進(jìn)文本到圖像的擴(kuò)散 Transformers
結(jié)果。嫁接模型在實(shí)時(shí)計(jì)算速度(wall-clock time)上實(shí)現(xiàn)了 1.43 倍的提升,同時(shí)生成評(píng)估分?jǐn)?shù)(GenEval)僅出現(xiàn)小幅下降(47.78 vs. 49.75)。特定屬性的指標(biāo)(Attribute-specific metrics)基本保持可比,并且定性樣本也展現(xiàn)出良好的對(duì)齊度和質(zhì)量。在一些紋理區(qū)域觀察到了局部性的失真(artifacts),這可能是由于 LoRA 的適應(yīng)能力以及所使用的合成數(shù)據(jù)質(zhì)量不高所致(失敗案例詳見圖 D.3,D.4)
要點(diǎn) 2:在文生圖 DiTs 中成功應(yīng)用嫁接技術(shù),構(gòu)建的混合架構(gòu)在實(shí)現(xiàn)顯著加速的同時(shí),生成質(zhì)量損失極小。
了解更多內(nèi)容,請(qǐng)參考原論文。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.