Are Your Continuous Approximations Really Continuous?Reimagining VI with Bitstring Representations
質疑連續近似:以比特串重塑變分推斷
https://openreview.net/pdf?id=SaidYid1Mq
摘要
在大規模模型中高效地執行概率推理是一項重大挑戰,這主要是由于計算需求高以及模型參數的連續性所致。與此同時,機器學習社區也在努力對大規模模型的參數進行量化,以提高其計算效率。我們在此基礎上提出了一種通過變分推斷(VI)來學習量化參數的概率分布的方法。這種方法使得在離散空間中也能有效地學習連續分布。
我們考慮了二維密度估計和量化神經網絡,并引入了一種使用概率電路(probabilistic circuits)的可處理學習方法。該方法為管理復雜分布提供了一個可擴展的解決方案,并能清晰地揭示模型的行為特征。我們在多種設置下驗證了我們的方法,證明了其有效性。
1. 引言
概率推理是現代機器學習的核心,它為在不確定性下的推理提供了有原則的框架。在貝葉斯推理中,不確定性通過參數的后驗概率分布來刻畫。然而,精確的貝葉斯推理通常是難以計算的,因此需要近似求解。
變分推斷(Variational Inference, VI;Blei等,2017;Jordan等,1999;Wainwright和Jordan,2008)通常被用于這一任務,作為一種可擴展的解決方案。但VI的一個局限在于它依賴于連續參數化形式,且常常受限于高斯假設,這可能在表示能力和計算效率上帶來問題,尤其是在大規模場景中。
為了應對計算上的限制,機器學習界越來越多地采用參數量化技術。這些方法通過降低數值精度來提升效率,利用低比特位表示來進行存儲和計算。最近的一些研究顯示,即使使用FP8混合精度(如Liu等,2024)、FP4(Wang等,2025),甚至1比特神經網絡架構(Ma等,2024),也能取得出人意料的良好性能。
于是引出了一個有趣的問題:我們是否也可以在量化參數的離散表示空間中直接進行概率推理?
作為一個初步的思想實驗,請參考圖1,它展示了如何將一個通常用高精度浮點數表示的高斯混合模型,用低精度的比特串表示等價表達。
本文提出了BitVI,一種在比特串模型中進行近似概率推理的新方法。BitVI 利用了數值表示固有的離散特性,在比特串空間中直接逼近連續分布。通過結合概率電路(Probabilistic Circuits, PCs;Choi等,2020),我們的方法提供了一種可在復雜分布上進行學習與推理的可操作方式,而無需高精度表示。
圖2 展示了 BitVI 僅需4比特精度即可建模復雜分布的能力。我們在以下兩個方面驗證了 BitVI 的效果:
(i)標準基準密度函數,展示其對已知分布的逼近能力;
(ii)貝葉斯深度學習中的神經網絡模型(基于Bayesian Benchmarks),其中 BitVI 實現了可擴展且直接的不確定性量化。
我們的結果突顯了 BitVI 在效率與準確性方面的優勢,使其成為傳統推理方法之外的一種有力替代方案。
2. 方法
在計算機上執行計算時,感興趣的參數不可避免地將以離散形式表示。每一個實數值都會通過計算機硬件中的一系列比特串(bitstrings)來表示,并通過一個映射函數 → R 映射到實數軸上,其中 B 表示比特位的數量,映射方式由所選用的數值系統決定(例如定點數表示法,見圖5)。
因此,任何在計算機中表示的“連續”分布 p 或 q ,都可以被表達為關于比特串的離散分布。接下來,我們將說明如何將連續分布表示為關于比特串的一個具有可處理性和靈活性的變分族(variational family)。
2.1 BitVI:在比特串表示上的變分分布
設 q? 是一個定義在二進制字符串的可測空間 (Y, A) 上的概率分布,其概率測度為 Q? ,其中 A 為對應的 σ-代數。再設 (R, B) 是實數空間,B 為其波萊爾 σ-代數(Borel σ-algebra)。進一步地,定義一個可測映射函數?: Y → R,它根據所選的數值系統(例如定點表示法,見圖5),將每個二進制字符串映射為一個實數。
通過該映射函數 ?,我們可以定義在 (R, B) 上的誘導概率測度 Q ,即 Q? 在 ? 下的前推測度(pushforward measure)。具體來說,對于任意一個波萊爾集 B ∈ B ,我們有:
因此,通過指定一個關于比特串的分布 q? 和相應的數值系統,我們就可以獲得一個在實數軸上的誘導變分分布 q。
接下來,我們的目標是找到變分分布的一個參數化形式 θ,使其與某個真實分布之間的 KL 散度最小化(見公式 (5))。當使用一個確定性概率電路來表示 q 時,參數 θ 對應于電路中所有權重 {wi}i 的集合。
需要注意的是,根據構造方式,我們電路模型中的葉節點建模的是連續均勻分布,因此它們不引入額外的參數。
最終得到的確定性概率電路是一棵樹,其深度與比特串表示中使用的比特數成正比。電路中的每一個和節點(sum node)代表一個比特位的選擇,其權重對應于該選擇的條件概率。
例如,在一個使用3位定點數表示、包含1個整數位且無符號位的系統中,數值 0.5 對應的比特串是010
。它的概率計算方式是基于各個比特位的選擇(b? = 0, b? = 1, b? = 0),并沿著電路路徑進行評估:
其中 表示小數部分的比特數。圖6 展示了電路所反映的決策過程。
ELBO 的計算
確定性概率電路的一個便利特性是:它的熵可以相對于電路邊數在線性時間內計算(Vergari 等,2021),詳見附錄D。
因此,在計算 ELBO 時,我們只需使用蒙特卡洛積分(Monte Carlo integration)來近似公式 (6) 中的期望對數概率項。
為此,我們首先通過逆累積分布函數變換(inverse CDF transform)進行重參數化,而這一變換在確定性概率電路中也可以解析地實現。
具體而言,我們采用如下的 ELBO 重參數化形式:
3. 實驗
我們在第3.1節中首先從一個二維密度估計任務開始,以展示我們的方法在捕捉復雜非高斯分布方面的有效性。
接下來,在第3.2節中,我們通過將 BitVI 應用于貝葉斯神經網絡(NNs),探索在貝葉斯深度學習設置下對高維后驗密度的學習能力,展示了其在預測建模中進行有效不確定性量化的能力。
此外,我們還進行了一系列消融實驗(見附錄F.1),以評估數值精度與模型表達能力之間的權衡,研究比特串長度對性能的影響,以及神經網絡中層次結構的作用。
3.1 二維密度估計
首先,我們在二維的非高斯目標分布任務中展示了我們所提出方法的靈活性。
在圖2中,我們展示了使用4比特和8比特精度的 BitVI 對一些典型基準目標密度函數的近似結果。這些目標分布包括:混合分布、Neal漏斗分布、雙模態高斯分布、環形分布和香蕉形分布。
此外,圖8 展示了對兩種密度函數的詳細對比,表明 BitVI 能夠很好地捕捉整體密度結構以及變量之間的交叉依賴關系,并且隨著比特數的增加,近似質量也隨之提升。
作為參考,圖8中還展示了真實的目標密度以及由全協方差高斯變分推斷(FCGVI)得到的近似結果。
3.2 高維密度估計:貝葉斯神經網絡
接下來,我們嘗試使用 BitVI 來近似貝葉斯神經網絡(NN)參數的后驗密度。
為了簡化實驗,所有神經網絡實驗中我們都采用了相似的網絡結構:所有實驗均使用兩個隱藏層,僅改變每層的神經元數量。此外,我們使用了層歸一化(layer norm;Ba等,2016),以限制權重的縮放范圍。
圖4 展示了一個不確定性量化(uncertainty quantification)的示例。我們使用一個具有 [8,8] 隱藏單元的神經網絡,在“雙月”二分類問題上進行了測試。
預測密度結果顯示,與確定性模型和均場高斯變分推斷(mean-field Gaussian VI)基線方法相比,BitVI 不僅能夠提供具有代表性的不確定性估計,還能生成良好的決策邊界。
為了對神經網絡建模任務進行更定量的評估,我們使用了貝葉斯基準測試套件(Bayesian Benchmarks1),這是一個用于評估機器學習中貝葉斯方法性能的標準社區工具集。
該套件包含了一些常見的評估數據集(通常來自UCI數據庫,Kelly等,2025),并允許在固定評估設置下進行多種方法的比較。
我們在二分類任務中評估了我們的方法,并特別關注了一些小樣本場景下的二分類任務(樣本量范圍為 100 ≤ n ≤ 1000),共涉及25個數據集。
我們遵循該評估套件中的標準設置,對輸入點進行了歸一化處理,并使用了預定義的數據劃分方式。
有關神經網絡結構和評估設置的更多細節,請參見附錄 E.2。
表1展示了 BitVI(使用 2 位、4 位 和 8 位精度)、均場高斯變分推斷(MFVI)以及全協方差高斯變分推斷(FCGVI)的實驗結果。
我們的方法在低比特設置下仍能與標準的 VI 基線方法表現相當,具有競爭力。使用 4 位和 8 位表示的 BitVI 在性能上與 MFVI 和 FCGVI 相當,這表明即使在基于比特串的表示空間中,也可以有效地進行概率推理,而不會顯著損失預測能力。甚至在某些情況下,2 位精度仍然具有可行性。
4. 結論
本文提出了BitVI,一種用于近似貝葉斯推理的新方法,該方法直接在離散比特串表示空間中進行操作。我們展示了:即使在基于比特串的數值表示上,也可以直接進行推理,同時實現有效的近似推斷和不確定性量化。
我們的實驗表明 BitVI 在不同任務中具有良好的靈活性:
在第3.1節中,我們展示了其對復雜非高斯二維密度函數的近似能力;
在第3.2節中,我們在高維的貝葉斯深度學習后驗推理任務中驗證了其有效性,表明它能夠在保持計算效率的同時提供穩健的不確定性估計。
盡管 BitVI 為靈活的變分推斷提供了有前景的方向,但仍存在一些局限性:
由于電路模型對每個參數的比特串進行建模,我們的方法引入了大量需要優化的新參數,這在高維設置下可能帶來進一步的挑戰。為了擴展到高維場景,我們的方法還采用了均場近似(mean-field approximation)來逼近后驗分布,這意味著模型參數之間的依賴關系未被建模。
然而在實際應用中,建模所有參數之間的依賴關系可能是不必要的。因此,一個有前景的未來方向是利用更緊湊的概率電路表示形式,例如 (Peharz 等, 2020) 所提出的方法。
附錄A. 背景與相關工作
連續表示與離散表示之間的關系是計算科學中的一個基本問題。從本質上講,數字計算依賴于離散結構,實數值被編碼為有限長度的比特串(Knuth, 1997,第4章)。浮點運算則是在這種離散框架下對連續值的一種近似,它在保證數值運算效率的同時也引入了固有的精度限制(Sterbenz, 1974,第1章)。
近年來,隨著量化技術和低精度算術的發展,這一基礎性聯系在機器學習領域重新受到關注。雖然這些技術最初主要是出于硬件資源限制的考慮,但它們也帶來了一個新的機遇:如果能夠直接在離散的比特串表示空間中進行推理,就可能在概率建模中實現新的效率提升。
貝葉斯推理提供了一個在不確定性下進行推理的原則性框架,但在大多數現實場景中,精確推理仍然是難以實現的。因此,研究者們發展出了一系列近似推理技術,其中就包括變分推斷(Variational Inference, VI)(Blei等,2017;Jordan等,1999;Wainwright和Jordan,2008)。
VI 將推理問題轉化為一個優化問題:通過擬合一個參數化分布來逼近后驗分布,并最小化其與真實后驗之間的反向KL散度。盡管VI具有良好的可擴展性,但它通常受限于其對連續參數化的依賴,而這種依賴可能會因均場假設或單峰性假設等限制性近似引入數值不穩定性和偏差。
這些局限性在低精度環境下尤為明顯,從而引發了一個關鍵問題:我們是否可以在離散表示空間中直接進行推理?
概率電路(Probabilistic Circuits, PCs)是一種近期提出的、用于研究復雜概率分布的可處理表示的框架(Choi等,2020)。根據PC的結構特性,在電路下某些推理任務可以在多項式時間內完成(即模型復雜度的多項式級),同時保持較高的表達能力。
雖然PC通常用于精確概率推理,但它在近似貝葉斯推理中也有成功應用,例如:
作為替代模型進行編譯推理(Lowd 和 Domingos,2010);
作為結構化離散模型的變分分布(Shih 和 Ermon,2020);
在離散概率程序中使用(Saad 等,2021)。
與我們的研究最相關的是,Garg 等(2024)利用基于比特串表示的PC在概率程序中實現了高效的近似推理。這項工作表明,PC 是一種自然且有前景的表示框架,適用于近似貝葉斯推理和不確定性量化。
附錄B. 動機
給定一個目標密度 p ,我們的目標是找到一個變分近似 q ,使其最小化 p 相對于 q 的散度。
通常情況下,我們將關注的是 q 相對于 p 的反向 Kullback-Leibler(KL)散度(reverse KL divergence),而不是正向KL散度。
此外,我們假設 q 具有參數形式,其參數為 θ,即 qθ 。因此,目標是找到參數 θ,使得:
關鍵在于,當在計算機上計算公式 (5) 或公式 (6) 時,每一個 x 都不可避免地會以離散形式表示。
事實上,每一個實數值都會通過一系列比特串(bitstrings)來表示,并通過一個映射函數 映射到實數軸上,其中 B 表示比特位數,映射方式由所選用的數值系統決定。因此,任何在計算機中表示的分布 p 或 q ,都可以用一個關于比特串的分布來表達。
圖5 展示了使用8位定點表示法對一個實數值進行表示的過程。
附錄C. 技術細節C.1. 定點數表示系統
定點數(Fixed-point)表示系統是將比特串轉換為實數值的一種可能方式。下面我們簡要回顧這種數值表示系統的符號-幅值形式(sign-magnitude form)。
在該表示方式中,最高有效位(most significant bit)用于表示數字的符號:
1 表示負數,
0 表示正數。
其余的比特稱為幅值位(magnitude bits),進一步劃分為:
- 整數位
(integer bits),
- 小數位
(fractional bits)。
顧名思義,整數位用于編碼所表示實數的整數部分,而小數位則用于編碼其小數部分。
具體來說:
整數位表示的是2的非負次冪是否存在;
小數位表示的是2的負次冪是否存在。
例如,在圖5中所示的比特串對應如下計算:
0 × 22 + 1 × 21 + 0 × 2? + 0 × 2?1 + 1 × 2?2 + 1 × 2?3 + 1 × 2?? (不包括符號位)
這代表了該比特串所編碼的數值的整數與小數部分的構成方式。
C.2 概率電路
我們將簡要回顧與本研究相關的一些概率電路(Probabilistic Circuits, PC)的核心概念。如需更詳細的介紹,我們建議讀者參考文獻(Choi 等,2020)。
簡而言之,概率電路是一種表示概率分布的有向無環圖(DAG)。
一個概率電路由以下類型的節點組成:
- 和節點
(Sum nodes,記為 S),
- 積節點
(Product nodes,記為 P),
- 葉節點
(Leaf nodes,記為 L)。
概率電路是一種計算圖:
- 和節點
對其子節點的輸出進行加權求和;
- 積節點
對其子節點的輸出進行乘積運算;
- 葉節點
代表單變量或多變量函數,例如高斯分布。
在對概率電路的結構施加某些約束條件后,可以高效地執行諸如邊緣化(marginalization)等推理任務。下面將簡要回顧其中的一些結構性質。
在本研究中,我們僅考慮同時滿足平滑性(smoothness)和可分解性(decomposability)條件的概率電路,因為這兩個條件對于高效執行常見的推理任務(如密度評估和邊緣化)是必要的。
C.3 多變量比特串表示
到目前為止,我們所構造的誘導變分分布僅定義在實數軸上(即單變量情況)。為了將該方法擴展到多變量情形,我們考慮了兩種方式:
- 均場變分族(mean-field variational family);
- 建模維度之間依賴關系的變分族
為了表示不同維度之間的依賴關系,我們在所有維度的比特狀態上聯合構建一個確定性概率電路(deterministic PC)。
在使用定點數表示系統的情況下,所構建的電路模型通過軸對齊分割(axis-aligned splits)遞歸地將定義域劃分為超矩形(hyper-rectangles),并在構建過程中交替使用不同的維度進行劃分。
需要注意的是,這種構造方式最終形成一棵二叉樹,其葉節點數量為 2B×D ,其中 B 為比特位數,D 為維度數。因此,這種方法更適合于低維或低精度場景。
然而,如果在模型中引入條件獨立性,則可以獲得更緊湊的表示形式(見 Peharz 等,2020;Garg 等,2024)。有關該構造方式的更多細節,請參見附錄 C。
其中,記號略有濫用地表示:
- C?
表示 Sd,?? 的左子節點,對應比特值為 0 的情況;
- C?
表示右子節點,對應比特值為 1 的情況。
由于我們在樹的每一層交替使用不同的維度,因此每一步的決策僅基于當前“選定”的維度。
計算 tree-CDF 變換的逆變換仍然可以高效完成,其時間復雜度為O(B × D),其中 B 是比特位數,D 是維度數。
為了鼓勵 q 在無限精度極限下具有平滑的密度函數,我們在優化電路權重時引入了一種深度正則化方案(depth regularization scheme)(詳見附錄 C.4 的進一步說明)。
如正文所述,在多變量分布的情況下,我們構建了一個電路模型,用于表示定義在超矩形(hyper-rectangles)上的分布。
設 Ω 表示該分布的定義域,我們遞歸地構造一個將定義域劃分為可測子集的二進制劃分(dyadic partition)。
這個過程通過在樹的每一層選擇一個分割維度,并根據數值系統的表示方式對超矩形進行劃分來實現。例如,在定點數系統中,劃分位置位于中間。
在下一層,我們從剩余未被劃分的維度中選擇一個新的分割維度,并相應地繼續劃分。
我們確保所有維度都被劃分過之后,才重新開始新一輪的劃分。
當每個維度都已經被劃分了 B 次之后,構建過程結束,其中 B 是數值系統中使用的比特位數。
圖7 展示了輸入定義域 Ω 被遞歸劃分為多個子定義域(即超矩形)的過程。
C.4 深度正則化
附錄 F. 額外實驗結果
以下部分包含額外的實驗結果。
F.1. 消融實驗(Ablation Studies)
目標分布復雜度的增加
我們考慮了一個消融實驗,用于控制目標分布的復雜度。為此,我們構造了一個等距高斯混合分布,并在每個高斯具有不同方差的情況下,評估 BitVI 在使用不同位數時的熵表現。
圖 9 顯示了 BitVI(黑色曲線)在使用 16 位時對復雜度不斷增加的目標分布(灰色曲線)的擬合結果,以及在不同位數下 BitVI 的熵變化。下方的熵圖顯示了表示每個目標所需位數的截斷點,這表明 BitVI 自然地表現出一種簡潔(parsimonious)的行為。
F.2. 模型復雜度與比特串深度之間的權衡
在神經網絡(NN)應用中,一個有趣的問題是:我們是否真的需要在模型權重的數值精度上做到非常精細?近年來的大規模模型訓練與推理研究表明,比起數值精度,模型更受益于更多的參數數量,從而獲得更高的靈活性。
圍繞這個問題,我們研究了模型是否能從概率處理中獲得更高數值粒度(numerical granularity)的好處。
在表 2 中,我們同時改變了神經網絡的復雜度(即兩個隱藏層中的單元數)和比特串的深度。我們考慮了 2 到 12 位的模型(僅使用小數位)。在“two moons”數據集上的負對數預測密度(NLPD,越小越好)結果表明,即使是低比特深度的模型也能表現良好,而表達能力的主要決定因素是神經網絡中的單元數量。
在附錄 F 中,我們還提供了關于準確率(accuracy)和期望校準誤差(ECE)的類似表格。
比特串是否捕捉到了神經網絡中的層次結構?
最后,我們使用一個神經網絡模型來研究 BitVI 所捕捉到的層次結構。我們從在 Banana 二分類數據集上訓練得到的一個 10 位 BitVI 神經網絡模型出發,逐步降低訓練后模型的小數精度,去掉模型中更細粒度的層級。
圖 10 展示了 10、8、6、4 和 2 位模型的結果(每個模型使用 2 位整數位,除了 2 位模型)。即使 4 位模型(2 位整數 + 1 位小數)也能很好地捕捉整體結構,而 2 位模型(沒有整數位;只有符號位和一位小數)則表現較差。
原文鏈接:https://openreview.net/pdf?id=SaidYid1Mq
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.