金磊 發自 凹非寺量子位 | 公眾號 QbitAI
天下苦大模型矩陣乘法久矣。
畢竟不論是訓練還是推理過程,矩陣乘法作為最主要的計算操作之一,往往都需要消耗大量的算力。
那么就沒有一種更“快、好、省”的方法來搞這事兒嗎?
有的,香港中文大學最新一篇僅10頁的論文,便提出了一種新算法:
- 能源可節省:5%-10%
- 時間可節省:5%
論文作者之一的Dmitry Rybin表示:
- 這項研究對數據分析、芯片設計、無線通信和LLM訓練都有著深遠的影響!
這么算矩陣乘法,更快!
矩陣乘法是計算機科學和數值線性代數中的核心問題之一。
自從Strassen和Winograd的開創性工作以來,研究者們一直在探索如何減少矩陣乘法所需的計算量。
盡管這類運算在統計、數據分析、深度學習和無線通信等領域有著廣泛應用,例如協方差矩陣的計算和線性回歸中的關鍵步驟,但對于具有特殊結構的矩陣乘法(如計算矩陣與其轉置的乘積XXt)的研究相對較少。
從理論角度看,計算XXt與一般矩陣乘法具有相同的漸近復雜度,因此只能通過常數因子優化來提升速度。
因此,這篇論文《XXt Can Be Faster》提出了一種名為RXTX的新算法,通過結合機器學習搜索方法和組合優化技術,顯著提升了XXt的計算效率。
我們先來了解一下RXTX。
整體來看,這個基于4×4分塊矩陣的遞歸乘法,通過機器學習搜索與組合優化相結合的方法發現。
算法主要包含以下關鍵步驟:
- 分塊與遞歸調用
- :將矩陣X劃分為16個4×4子塊,通過8次遞歸調用處理子問題,并計算26個一般矩陣乘積m1至m26。
- 對稱乘積計算
- :直接計算8個子塊的對稱乘積s1至m8。
- 結果組合
- :通過線性組合上述乘積結果,得到最終的XXt矩陣各分塊元素C11至C44。
與此前最先進的算法(基 Strassen的遞歸分治)相比,RXTX的遞歸關系式為 R(n)=8R(n/4) + 26M(n/4),而原算法為 S(n) = 4S(n/2) + 2M(n/2)。
這一設計使得RXTX的漸近乘法常數為 26/41≈0.6341,比原算法的2/3≈0.6667降低了約5%。
接下來,我們來看下乘法次數與運算總量分析。
通過論文中的定理1的推導,RXTX的乘法次數表達式為:
實驗數據表明,當n為4的冪次時,RXTX的乘法次數比原算法低5%,且隨著n增大,這一優勢持續保持:
通過優化加法步驟(利用公共子表達式減少加法次數),RXTX的總運算量表達式為:
而原算法的總運算量包含對數項,導致其增長更快。
實驗顯示,當n≥256時,RXTX的總運算量優于原算法;當n≥1024時,顯著優于樸素算法:
在6144×6144矩陣的測試中,RXTX的平均運行時間為2.524秒,比BLAS的默認實現快9%,且在99%的測試中表現更優:
盡管運行時間受硬件和內存管理影響,但理論分析表明,當n≥256時,RXTX即可展現速度優勢。
值得一提的是,RXTX的發現得益于機器學習與組合優化的結合,具體流程如下:
- RL代理生成候選乘積:通過強化學習策略生成大量可能的秩-1雙線性乘積。
- MILP枚舉與篩選:
- MILP-A:枚舉候選乘積與目標表達式(XXt的各分塊)之間的線性關系。
- MILP-B:選擇最小的乘積子集,確保所有目標表達式可通過線性組合表示。
- 大鄰域搜索迭代:通過迭代優化,逐步減少冗余乘積,提升算法效率。
這一方法借鑒了AlphaTensor的思路,但通過限制候選空間為二維張量,顯著降低了計算復雜度,使得MILP求解器(如 Gurobi)能夠高效處理。
論文地址:
https://arxiv.org/abs/2505.09814
參考鏈接:
[1]https://x.com/DmitryRybin1/status/1923349883945181392
[2]https://x.com/vikhyatk/status/1923541713618129273
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.