Bit Blasting Probabilistic Programs
新的概率編程語言HyBit
https://github.com/Tractables/Dice.jl/tree/hybit
https://arxiv.org/pdf/2312.05706v2
摘要
概率編程語言(Probabilistic Programming Languages, PPLs)是一種用于創建和推理概率模型的富有表現力的工具。不幸的是,當今的PPL對混合概率程序(即同時涉及連續和離散結構的概率程序)支持不夠良好。在本文中,我們開發了一種新的近似推理算法,用于混合概率程序,該算法首先對連續分布進行離散化,然后在得到的程序上執行離散推理。我們的核心創新是一種稱為“位爆破”(bit blasting)的離散化形式,它使用數字的二進制表示,使得一個包含2^個離散點的域可以簡潔地表示為一個由poly()個布爾隨機變量構成的離散概率程序。令人驚訝的是,我們證明了許多常見的連續分布可以通過“位爆破”進行離散化,而不會相對于顯式離散化造成任何精度損失,并且還支持高效的概率推理。我們構建了一個名為HyBit的概率編程系統,專門用于處理混合程序,該系統采用位爆破方法,隨后進行離散概率推理。我們通過實驗驗證了我們的方法相比于現有的基于采樣的和符號推理方法的優勢。
CCS概念:? 計算數學 → 概率表示;概率推理問題
附加關鍵詞和短語:離散化、位爆破、概率推理
ACM參考格式:
Poorva Garg, Steven Holtzen, Guy Van den Broeck 和 Todd Millstein. 2024. Bit Blasting Probabilistic Programs. Proc. ACM Program. Lang. 8, PLDI, Article 182 (June 2024), 24 pages. https://doi.org/10.1145/3656412
1 引言
概率編程語言(PPLs)是一種用于創建和推理概率模型的富有表現力的工具。許多這樣的模型本質上是混合的,即同時涉及連續結構(例如高斯分布)和離散結構(例如伯努利隨機變量、if語句及其他控制流)。例如,在醫療診斷、基因表達和信息物理系統等應用中就會出現混合模型 [Chen et al. 2020; Lee and Seshia 2017]。
然而,今天的PPL并不能很好地支持混合程序。概率編程語言中的主要分析任務是概率推理,即計算某個事件根據程序定義的分布發生的概率。現有的推理算法通常采用采樣方法來進行近似推理。一些方法,特別是哈密頓蒙特卡洛方法(Hamiltonian Monte Carlo),被用于Pyro和Stan等PPL中 [Bingham et al. 2019; Gorinova et al. 2021],但它們不支持離散隨機變量,而是要求將這些變量(手動或自動)邊緣化出去。然而,這種方法在許多情況下會導致指數級爆炸,尤其是在離散變量數量較多時。其他基于采樣的方法具有通用性,因此可以處理離散性,如重要性采樣、馬爾可夫鏈蒙特卡洛和序列蒙特卡洛等 [Koller and Friedman 2009]。然而,這些算法在面對由于離散性產生的多峰分布 [Yao et al. 2021] 以及那些以低概率事件為條件的程序時,常常表現不佳。
在本文中,我們提出了一種通過離散化來處理混合概率程序的新推理算法:我們將混合程序中的連續分布轉換為離散分布。這產生了一個完全離散的概率程序,可以在其上使用現有的離散推理算法。離散化將連續分布近似為一系列區間,每個區間對應值落在該區間內的概率。先前的研究 [Albarghouthi et al. 2017; Beutner et al. 2022; Claret et al. 2013; Huang et al. 2021] 中已經使用過各種形式的離散化方法,但它們的規模都與區間的數量呈線性關系。這就帶來了一個明顯的權衡:需要足夠多的小區間以避免過多的精度損失,但隨著區間數量的增加,推理的成本迅速變得不可承受。
我們引入了一種新的離散化方法,稱之為“位爆破”(bit blasting),這一術語借鑒自驗證領域中的同名技術 [Bruttomesso and Sharygina 2009]。位爆破離散化的一個關鍵特性是,它僅需使用 poly(log) 個布爾隨機變量即可表示 個區間的離散化。這是通過使用數字的二進制表示,并將離散化表示為基于這種二進制表示的離散概率程序來實現的。乍一看,這種簡潔的表示方式似乎會丟失太多精度,從而無法成為一個可行的策略,但我們通過理論和實驗證明事實并非如此。
首先,我們證明了一大類常見的連續密度函數可以被“位爆破”安全地處理,也就是說,其精度與樸素離散化相比沒有任何損失。表1列出了屬于這一類的示例分布;我們將整個這類分布稱為混合伽馬分布(mixed-gamma distributions)。
例如,考慮將0到1之間的連續均勻分布進行離散化。樸素地將其離散化為232個區間需要枚舉232個值。相反,這個分布可以以二進制形式表示為一個由32個flip(0.5)形式的伯努利隨機變量組成的元組,即擲硬幣,其結果等概率為0或1。這一觀察并不新穎,類似的結果也適用于指數分布 [Marsaglia 1971]。然而,其他混合伽馬密度函數的位爆破離散化是新穎的。此外,與均勻分布和指數分布不同,這些離散化并不是簡單地定義為一組獨立的伯努利隨機變量,而是需要基于這些變量構建完整的離散概率程序。
一種簡潔的表示方式并不一定意味著推理效率高,因為一般情況下推理本身就很困難。作為我們的第二個貢獻,我們證明了位爆破后的混合伽馬分布在不僅是正確且簡潔的前提下,還支持在位精度數量上的多項式時間推理。具體來說,我們證明了使用知識編譯(knowledge compilation)方法進行離散概率推理 [Chavira and Darwiche 2008; Chavira et al. 2006; De Raedt et al. 2007; Fierens et al. 2015; Holtzen et al. 2020],該方法將推理歸約為布爾公式上的加權模型計數(weighted model counting),對于位爆破后的混合伽馬分布具有上述性質。因此,對混合伽馬分布進行位爆破后,再通過知識編譯進行離散推理的過程,其運行時間被保證為離散化所用位寬的多項式級別。
第三,我們利用上述理論成果設計了一個新的概率編程系統HyBit,它通過位爆破實現對混合概率程序的非隨機近似推理。HyBit是一種離散概率語言,支持定點二進制數及其相關的算術運算,并作為嵌入式的領域特定語言在Julia中實現 [Bezanson et al. 2017]。HyBit的API允許用戶生成任意連續分布的位爆破定點近似。混合伽馬分布可以通過其正確的位爆破離散分布來表示。對于其他分布,API允許用戶使用分段離散近似,其中每一段本身就是一個位爆破后的混合伽馬分布。
HyBit利用知識編譯對給定的離散概率程序執行精確推理。在最壞情況下,通過對混合概率程序進行位爆破推理的時間復雜度可能是位數的指數級,但我們通過實驗驗證了我們方法的優勢。我們展示了HyBit在一套全面的混合概率程序基準測試中,表現優于現有的基于采樣和符號推理的方法。
總體而言,本文提出了以下貢獻:
我們在第2節中闡述了混合概率程序推理中的挑戰。
我們在第3節中提出了一種新的離散化方法,稱為位爆破,其特點是離散區間數量的簡潔性。
我們提出了一類連續分布,即混合伽馬分布,在這類分布中存在正確的位爆破表示。我們在第3節中形式化了這種構造并證明了其性質。我們進一步證明,基于知識編譯的推理在這些分布上能夠隨位寬呈多項式級擴展。
我們在第4節中描述了HyBit概率編程語言及其通過位爆破實現的新推理算法。
在第5節中,我們將HyBit與其他PPL在來自現有文獻的基準測試中進行了實證比較。我們還分析了HyBit在其超參數(即位數和分段數)下的行為表現。
HyBit的項目地址為:https://github.com/Tractables/Dice.jl/tree/hybit 。
本文的完整版本可在arXiv上獲取,編號為[Garg et al. 2023],其中包含完整的證明過程。
2 動機示例
本節通過三個示例說明對混合概率程序進行推理的挑戰。首先,我們展示一個來自計算生物學的例子,其中包含固有的邏輯結構。接下來,我們展示文獻中的一個例子,由于離散控制流的存在,其后驗分布呈現多模態。最后,我們通過共軛高斯分布展示低概率觀測值的一個例子。隨后,我們研究了包括 HyBit 在內的各種推理算法的表現。這些示例展示了 HyBit 相對于其他近似推理算法的優勢。
我們在第5節中提供了與多個近似和精確基線方法的詳細比較。
2.1 邏輯結構
我們展示了一個來自計算生物學的簡化示例,該示例將基因表達與血糖水平聯系起來。圖1展示了對應的概率程序,其中的任務是:在已知患者的血糖水平的前提下,更新關于該患者體內某個基因出現概率的信念。
圖1中概率程序的前四行使用beta分布作為一般人群中個基因出現的先驗概率。語法flip()
表示一個成功概率為的伯努利隨機變量。在第5行,程序使用reduce(|, gene)
語法來表示表達式????? gene[i]
,換句話說,如果至少有一個基因被表達,則認為該患者患有糖尿病。
接下來是關于患者血糖水平的多次測量讀數。對于每一次讀數,程序首先根據患者是否患有糖尿病定義了一個隨機變量來表示血糖水平。然后我們使用observe(y, v)
語法來對隨機變量取值為進行條件約束——在程序中,這用于對患者實際的血糖讀數進行建模。最后,在第12行,程序查詢了第一個基因出現的后驗分布的期望值。
圖2展示了使用不同的推理算法對該程序進行推理的結果,在20分鐘的超時限制下,隨著基因數量(T)的增加,各算法的表現情況。
Stan 使用的是哈密頓蒙特卡洛方法(Hamiltonian Monte Carlo),這種方法不直接支持離散隨機變量。取而代之的是,需要通過手動或自動的方式將這些變量邊緣化出去,例如使用變量消除(variable elimination)[Gorinova 等 2020]。如圖所示,當基因數量超過15個時,Stan 就會超時——因為其計算復雜度隨 T 呈指數級增長。
同樣的指數爆炸問題也困擾著 GuBPI [Beutner 等 2022],該方法采用符號求值與離散化相結合的方式來計算上下界。正如圖中所示,通用采樣方法如帶有 Metropolis Hastings 核的馬爾可夫鏈蒙特卡洛(WebPPL MH)和序列蒙特卡洛(WebPPL SMC)能夠擴展規模并提供合理的準確性。
精確推理策略 Psi [Gehr 等 2016] 在程序設計避免產生大規模離散狀態空間的前提下,也能很好地應對離散隨機變量數量的增長。AQUA 對混合程序進行離散化處理 [Huang 等 2021],但無法支持本例中的這個程序。
我們的系統與方法 HyBit 能夠擴展到 50 個基因,并且在所有近似推理算法中具有最小的絕對誤差。當該程序以 HyBit 編寫時,它通過其離散的位級抽象進行表示。用戶在編寫程序時,可以使用 HyBit API 將所有的連續分布(具體來說是圖1中的第2、6、9行)替換為其位爆破后的離散近似版本。
結果是,我們現在得到了一個離散程序,其中包含對布爾值和定點數的分布。請注意,現在第8行的 observe
條件是基于這樣一個事實:blood_sugar1
的離散分布取值落在對應數值79的區間內。
在本次實驗中,我們使用了 25位 的位寬——每個連續分布被離散化為一個由25位組成的元組所表示的程序,這些位被解釋為一個定點數。此外,我們使用了 4096個分段 來近似這些連續分布,每個分段本身是一個位爆破后的指數分布。
如果采用樸素的離散化方式,使用25位將是非常緩慢以至于不可行的,因為它會產生 22? 個區間,即約1.34億(134M)個區間。然而,我們通過位爆破后的程序僅使用了 53K 次擲硬幣 (即布爾隨機變量)來表示它們。
此外,基于知識編譯的推理方法 [Fierens 等 2015; Holtzen 等 2020] 自動識別并利用了程序邏輯結構中的條件獨立性,從而幫助推理過程實現良好的擴展性。
有關此實驗的更多細節,請參見附錄。
2.2 處理多模態性
本節展示了一個多模態分布的例子,用以突出混合概率程序推理中的另一個挑戰。多模態分布在多個高峰之間被低概率區域分隔,這類分布在諸如傳感器網絡定位、宇宙學等多個應用領域中非常常見 [Shaw 等 2007; Tak 等 2018]。我們改編了文獻中的一個例子 [Yao 等 2021],如圖3所示。
圖3中所示的概率程序對現有的概率推理方法來說具有很大挑戰性。第3行的datapts
包含九個數據項,其中三分之二為5,其余為?5。這導致了(?, ?)的后驗分布圍繞(5, ?5)和(?5, 5)呈現雙峰分布。
隨著數據點數量按相同比例增加,(?, ?)的后驗會收斂到(5, ?5)。然而,在數據點數量有限的情況下,?的后驗在5和?5附近呈現雙峰分布。
多個模態的存在給基于采樣的算法帶來了挑戰,因為它們容易陷入某一個模態中無法跳脫。具體來說,使用Metropolis Hastings核的WebPPL MCMC算法,以及使用HMC的Stan和WebPPL,最終都會任意地停留在其中一個模態中,并未能探索另一個模態。
圖4a和圖4b分別展示了使用WebPPL MCMC和Stan HMC得到的結果,其中兩次不同的運行分別被困在了不同的模態中。另一方面,HyBit對其離散抽象執行精確推理,因此能夠全局探索整個分布,從而識別出兩個模態。序列蒙特卡洛(SMC)(圖4c)、Psi 和 GuBPI 同樣能夠進行全局探索,因此也能找到這兩個模態。最后,為了應對直接離散化帶來的計算挑戰,AQUA通過調整其離散化區間,專注于高概率區域,結果只識別出了更高概率的那個模態(圖4d)。
2.3 處理低概率觀測
本節展示了一個具有低概率觀測值的共軛高斯示例(見圖5)。在圖5中,在第2行和第3行對低概率數據進行條件化之后,程序查詢了隨機變量mu
的后驗分布。
為什么低概率觀測值難以處理?直觀上來說,通用的采樣算法從先驗分布開始采樣,但在尋找具有顯著權重的樣本時會遇到困難。只有在進行大量采樣之后,這些算法才能逐漸接近真實的后驗分布。
另一方面,HyBit 通過對位爆破抽象執行精確推理,能夠全局探索后驗分布的定義域,因此不受這一問題的影響。
圖6繪制了真實的先驗分布和后驗分布,以及來自不同推理算法的結果。
對于基于采樣的算法——使用Metropolis Hastings核的MCMC 和 SMC,我們運行了相應的WebPPL程序 [Goodman and Stuhlmüller 2014],并在之后獲得了1000個樣本并進行了繪圖。重要性采樣(importance sampling)算法未能為此程序獲得任何具有非可忽略權重的樣本。
這些采樣器正在將后驗分布向真實后驗靠近,但需要更多的樣本才能實現這一點。即使分別采樣了約1600萬次和65,000次之后,MCMC和SMC所獲得樣本的期望值仍然存在絕對誤差 分別為 0.549798 和 1.520776 。
GuBPI 報告了每個區間上概率的上下界,并產生了 2.33 的絕對誤差。
另一方面,HyBit 所得到的后驗分布與真實分布完全重合。
Stan HMC 在處理低概率觀測方面表現良好,并獲得了很高的精度。
Psi 同樣得到了后驗分布的精確符號表達式。
最后,AQUA 所報告的后驗均值的誤差為 5.66 ,因為它未能對的先驗做出任何更新。
3 位爆破:核心見解
為了在混合概率程序的離散結構上實現推理的擴展性,我們需要一種將離散性作為頭等公民對待,并能有效消除連續結構的算法。本節定義了“位爆破”(bit blasting)的語義概念,并將其設定為具有優良性質的離散化的一種特殊情況。接著,我們為常見類別的連續分布提供了位爆破函數。我們所提供的離散化技術具備正確性 (精度可達位)、簡潔性 ,并適用于高效的推理。
3.1 離散化與位爆破
在概率論的標準術語中 [Rosenthal 2006],一個概率空間 (Ω, Σ, ) 由樣本空間 Ω、Ω 上的 -代數 Σ,以及定義在 Σ 上的概率測度 組成。
廣義而言,一個離散化函數 的輸入是一個這樣的概率空間 (Ω, Σ, ),輸出則是一個離散概率空間(Ω, Σ, ),其中 Ω 是一個可數集合。
我們將研究一種更具體的離散化概念:該離散化以有限區間上的連續分布 為輸入,并輸出一個在 2? 個點上的離散分布 ,其中 為位數。
形式上,設 [, ) 是一個區間,其中 , ∈ ?。對于輸入概率空間,我們使用 B([, )) 來表示區間 [, ) 的子集所構成的 Borel -代數 [Rosenthal 2006]。
對于輸出概率空間,我們用 P() 表示集合 的冪集(即其對應的 -代數)。
此外,我們假設樣本空間的離散化方式如下所示。
在我們定義一個 位的位爆破函數 之前,我們需要先確定一個通用的離散概率分布表示方式。為此,我們定義了一個稱為離散概率閉包 (discrete probabilistic closure)的概念,類似于概率圖靈機 [Arora and Barak 2006]。
每個概率閉包是一個從一組有偏硬幣擲(biased coin flips)到某個離散集合的確定性函數 。該函數通過與硬幣擲結果相關的概率,在其輸出上誘導出一個概率分布。此外,它還包含一個接受布爾公式 (accepting Boolean formula),用于處理觀測,并限制輸入硬幣擲所能取的值的集合。
下面我們將給出其形式化定義:
請注意,正如示例3所解釋的那樣,任何在 2? 個值上的離散分布 都可以表示為一個離散概率閉包 (, , ),其中 || = 2? ? 1。
Dice [Holtzen 等 2020] 和 Problog [Fierens 等 2015] 就是直接符合離散概率閉包這一范式的概率編程語言(PPL)的例子。
我們希望離散概率閉包能夠更加簡潔 ——其大小應為精度位數 的多項式級別。為此,我們定義了位爆破函數 。
定義5(位位爆破函數) :一個 位位爆破函數 [.]? 是一種 位的離散化函數,它輸出一個離散概率閉包 (, , ),該閉包所使用的布爾隨機變量數量是位數 的多項式級別,即 || ∈ (poly())。
根據定義3、定義4和定義5可知,對于任意整數 > 0,一個 位位爆破函數對于給定的概率空間 ([, ), B([, )), ) 是正確 (sound)的,當且僅當:
我們想指出的是,bitblast_unif
僅使用了 2次硬幣擲 就表示了一個在 4個值 上的分布。它可以推廣為使用 次硬幣擲 來表示一個在 2? 個值 上的均勻分布。
相比之下,na?ve_unif
使用了 3次硬幣擲 ,并會推廣為使用 2? ?1 次硬幣擲 。
因此,bitblast_unif
是一個適用于 位位爆破函數(-bit blasting function)的有效輸出。
在下一節中,我們將詳細介紹指數分布和混合伽馬分布的 位位爆破函數的具體實現。
3.2 具體的位爆破函數:預備知識
接下來,我們的目標是為混合伽馬分布 (mixed-gamma distributions)提供一個具體的、滿足正確性要求的位爆破函數實例 。混合伽馬分布的概率密度函數定義如下。
為了記號上的方便,我們將連續分布限制在單位區間 內,從而得到定義在 位單位區間 上的離散分布。隨后我們會將我們的方法推廣到任意有限區間,以構建基于位爆破的概率編程系統 HyBit 。
為了描述我們對位爆破函數的構造,我們使用了 Dice [Holtzen 等 2020]。Dice 已經能將其程序編譯為帶權布爾公式 (通過 ? 判斷),這些公式符合離散概率閉包 的定義。3
這使得我們只需定義一個從概率密度函數到 Dice 程序的 ?? 判斷,即可指定一個位爆破函數。
Dice 還定義了一個分布語義函數 J.K?:p → V → [0, 1],它以一個 Dice 程序 p 作為輸入,并輸出一個歸一化的概率分布(表示為從值集合 V 到概率的一個函數)。我們在后續使用函數 J.K? 來論證我們構造的正確性。
有關 Dice 的語法和語義的更多細節,請參見附錄。
混合伽馬密度函數的編譯判斷形式為 Υ ?? p ,其中 p 是 Dice 程序。
我們進一步給出以下定義:
- 密度函數與 Dice 程序之間的 等價性
(-equivalence)
- Dice 程序的 簡潔性
(-succinctness)
我們定義 -簡潔性 (-succinctness),要求 Dice 程序 p 中使用的硬幣擲次數與位數 成線性關系。請注意,-簡潔性 所施加的條件比 位位爆破函數 所需的條件更為嚴格(后者只要求使用 poly() 次硬幣擲)。這意味著,如果我們為某個混合伽馬密度函數建立了一個滿足 -簡潔性 的判斷,那么我們就可以為該分布構造一個合法的
3.3.1 指數分布 ?,
讓我們首先考慮均勻分布 (?,?),它是指數分布的一個特例。
如果我們使用 位 對均勻分布進行位爆破,并將其劃分為 2? 個區間,那么我們會得到一個在 [0, 1]? 上定義的離散分布 ?,它包含 2? 個離散點,每個點的概率為 1 / 2?。
一種直接的離散化策略是枚舉這 2? 個值,這需要使用 2? ?1 次硬幣擲。
但我們可以用更高效的方式實現同樣的效果:使用一個由 位 構成的元組,其中每一位都是一個無偏的硬幣擲 flip(0.5)。
之所以可以這樣對均勻分布進行位爆破,是因為二進制數字之間具有獨立性 。這一策略同樣可以推廣到一般的指數分布。這一觀點早在統計學的經典論文中就被提出 [Marsaglia 1971]。我們通過以下規則形式化這一思想。
我們在附錄中提供了上述引理的詳細證明。在本文的其余部分,我們將指數分布視為構建其他分布的基本元素,因為只有它們具備二進制位之間相互獨立 這一特性。然而,正如我們接下來將展示的那樣,對于其他分布,構造正確的位爆破函數 仍然是可能的。
3.3.2 伽馬分布 ?,
為了為 ?, 構造一個正確的位爆破函數,我們提出一個關鍵的數學洞察。
考慮圖8a中的程序。連續隨機變量 X 和 Y 分別服從均勻分布(?,?)和指數分布(?,)。該程序返回在條件 Y < X 下 X 的新分布。
結果表明,這個后驗分布是一個特定的伽馬分布 ?,。
下面我們展示相應的推導過程,其中 pdf 表示概率密度函數。
如果我們使用 位 對圖8a中的程序進行離散化,會發生什么?
我們會得到圖8b中的程序,其中每個連續隨機變量都被其對應的位爆破版本 所替代(例如 X 被替換為 X?,依此類推)。我們已經知道,對于均勻分布 和指數分布 ,存在滿足 等價性 的 Dice 程序。但問題是:其他構造 (如條件語句)會怎樣呢?
正如圖8d所示,observe(Y? < X?)
相較于其連續版本 observe(Y < X)
會產生誤差。好消息是,我們可以通過以下公式來量化并處理這種誤差 :
規則 Expo1 和 Trans-expo1zero (見附錄)捕捉了上述直覺。
其中,unifObs(y, b) = p
是一個輔助判斷,程序 p
構造了一個均勻分布 ,并通過 observe
對其進行條件約束,使其取值小于 y 。
3.3.3 廣義伽馬分布 ?,
前一小節展示了通過對不等式 (Y < X) 進行條件約束,如何在 ?, 的基礎上引入一個線性因子,從而得到 ?,。請注意,無論隨機變量 X 的初始概率密度函數是什么形式,對 (Y < X) 進行條件約束都會引入一個線性因子 。也就是說,如果 X 的初始概率密度為 (),那么如下概率程序的輸出密度將是 ()。這意味著:如果我們能夠對 () 進行位爆破,我們也就能夠對 () 進行位爆破。但我們仍需考慮由 observe(Y? < X?)
所帶來的誤差,即 Pr(? | ? == ?, < )。結果表明,這個修正項是一個伽馬分布的混合分布 ,而這類分布也可以被正確地進行位爆破 。
我們在附錄中提供了相應的判斷規則以及以下引理的證明。
3.3.4 伽馬分布的混合形式 ∑? ???,?
由于廣義伽馬密度函數 ?, 可以被進行位爆破,因此混合伽馬密度函數 也可以被正確地進行位爆破。
我們對每一個單獨的廣義伽馬密度進行位爆破,然后使用 if-then-else
結構來構建它們的混合形式,如下所示。
之前的研究 [Holtzen 等 2020] 定義了判斷 ?,它以 Dice 程序 p 作為輸入,并輸出一個帶權布爾公式 (, , ),該公式與離散概率閉包 (定義4)的定義一致。
并且由于根據定理8 ,對于所有混合伽馬密度函數,判斷 ?? 是 -簡潔的 (-succinct),因此 ? 總是輸出一個使用 poly() 次硬幣擲的 。
因此,復合判斷 ? ? ?? 構成了一個合法的 位位爆破函數 。
此外,之前的研究 [Holtzen 等 2020] 還證明了將程序編譯為帶權布爾公式的正確性,即其結果與 Dice 程序的語義是一致的。
結合這一結論和定理7 ,我們可以得出:復合判斷 ? ? ?? 是一個正確的 位離散化函數 。
詳細的證明可以在附錄中找到。
3.3.5 示例:拉普拉斯分布
前面的章節描述了當混合伽馬密度被限制在單位區間時,如何對其進行位爆破。那么,對于那些被平移 或縮放 到其他有限區間上的分布,又該如何進行位爆破呢?
我們通過拉普拉斯分布 (Laplace distribution)來解釋這一點。
拉普拉斯分布有兩個參數:
位置參數(location)
尺度參數(scale)
其概率密度函數如下所示,其中 ∈ ?:
我們考慮在區間 [ ? , + ) 上被截斷的拉普拉斯分布。我們假設 是一個合適的 2 的冪次 ,這樣與 的乘法運算(記作 × p)可以簡化為小數點移位操作 ,同時假設 是一個可以用 位 精確表示的數,以便對 p 進行精確移位。
首先,我們生成在寬度為 而非 1 的區間上的縮放后的指數分布:
3.4 位爆破如何助力推理?
我們已經展示了針對混合伽馬分布 的正確位爆破函數 。但這種位爆破方式如何幫助對包含這些分布的概率程序進行推理呢?
為了回答這個問題,我們聚焦于一種特定的推理策略——知識編譯 (knowledge compilation)。
我們首先介紹有關知識編譯的一些必要預備知識,然后論證通過 ?? 所獲得的程序為何在知識編譯中具有高效性。
基于知識編譯的方法 [Fierens 等 2015; Holtzen 等 2020] 用于精確的離散概率推理 ,其核心思想是將離散概率程序編譯為帶權布爾公式 ,并使用有序二元決策圖 (Ordered Binary Decision Diagrams, OBDDs)來表示這些公式。
當程序返回一個 單個布爾隨機變量 時,對應的 OBDD 是 單根 的;
當程序返回一個 布爾隨機變量的元組 時,對應的 OBDD 是 多根 的。
通過為程序中的所有硬幣擲(即帶有權重的有偏硬幣擲)設定一個取值,并沿著這些取值在 OBDD 中進行遍歷(實線表示 true,虛線表示 false),我們可以到達對應每一位輸出值的終端節點。
加權模型計數 (Weighted Model Counting)操作可以計算每個輸出位到達“1 終端”的概率。這是一個動態規劃算法,其運行時間與 OBDD 的大小成線性關系。
一個布爾公式 對應的 OBDD 大小,記作 OBDD(),是指該 OBDD 中的節點數量。
因此,如果我們能為某個分布構造出更小的 OBDD 表示,就可以更高效地將其與其他結構組合進一個離散概率程序中。
接下來我們正式討論并證明:通過判斷 ?? 所獲得的每一個程序都可以被編譯為一個帶權布爾公式,進而被編譯為一個多根 OBDD ,其大小隨著位數呈線性增長 ,而不是最壞情況下的指數增長。
回顧一下,Dice 程序會被編譯為一個帶權布爾公式 (, , ):
是一個(或一組)布爾公式,對應程序的返回值;
是一個接受布爾公式,用于編碼觀測條件;
是一個權重函數,表示各個硬幣擲的概率。
定理10 :對于所有 Υ, p, , , ,存在某個常數 ,使得對任意位數 ,如果 Υ ?? p 且 p ? (, , ),那么存在一個布爾隨機變量在 中的變量順序 Π,使得:
OBDD() + OBDD() ≤
我們現在為上述定理的證明提供直覺解釋。
請注意,在通過判斷 ?? 所獲得的程序中,只有兩個構造依賴于位數 :
- 指數分布 ?, 的構造
- 通過對不等式的條件約束
(即通過
unifObs(y, b)
對指數分布與均勻分布之間的關系進行建模)
接下來我們將說明這些構造所對應的 OBDD 大小如何隨著位數 線性增長 。
3.4.1 指數分布
在圖9中,我們對一個指數分布進行了 3位爆破 ,也就是說,我們在8個值上得到了一個離散的指數分布:{0, 0.125, 0.25, ..., 0.875}
。
圖中展示了一個三根節點的 OBDD ,其中每個根節點分別標記為 ?、? 和 ?,代表返回的3位值中的每一位。
假設我們將對應節點 ? 的硬幣擲設為 true,那么對于 ? 來說,我們會到達終端 1,并將其值設為 1。
加權模型計數 (WMC)操作會計算 ? 取值為 1 的概率為 6.14×10??,因為節點 ? 為真的概率是 6.14×10??。
同樣地,WMC 也可以用于該 OBDD 的其他根節點。
由于每一位只需要一個 OBDD 節點,因此整體的 OBDD 大小隨位數 線性增長 。
又因為 WMC 的運行時間與 OBDD 的大小成線性關系,所以對指數分布的推理時間復雜度為 O()。
附錄中還給出了一個關于均勻分布的 OBDD 示例。
對于所有通過判斷 ?? 得到的程序 p,如果 p ? (, , ),那么:
是表示程序返回值的一組布爾公式。
我們論證說,p 的返回值始終是一個指數分布的混合體 ,因此其對應的 OBDD 大小也與位數呈線性關系。
3.4.2 對指數分布與均勻分布之間的不等式進行條件約束
判斷 ?? 的規則使用了輔助判斷 unifObs(y, b) = p
,用于對均勻分布與指數分布的二進制表示之間的不等式 進行條件約束。
由于唯一依賴于位數的構造(即指數分布和不等式)都隨著位數呈線性增長,因此定理10在直覺上是成立的。我們在附錄中提供了形式化的證明。
4 HyBit:一個概率編程系統
前一節介紹了如何對混合伽馬分布 進行位爆破 (bit blasting)。我們進一步利用這一技術構建了一個用于混合概率程序 的概率編程系統 HyBit 。
本節將描述其語法與實現 ,并詳細闡述兩個重要方面:
- 連續分布的分段近似方法
- 采用二進制表示的優勢
我們構建了一個概率編程系統 HyBit ,它圍繞混合伽馬密度函數的正確位爆破 (sound bit blasting)以及其他連續分布的近似位爆破 展開。
HyBit 是作為一門淺嵌入式領域特定語言 (shallow embedded DSL)在 Julia [Bezanson 等 2017] 中實現的。
圖10給出了 HyBit 表達式的核心語法。它支持以下功能:
對布爾值的概率分布建模:
flip
(即伯努利分布)對定點數的概率分布建模:
general_gamma
和bitblast
布爾運算:?, ∧, ∨
算術運算:+, ?, *, /, %, <, ==
概率條件建模中的硬性觀測:
observe
對于圖10中列出的所有構造,HyBit 會執行一種非標準解釋執行 (non-standard execution),并將其編譯為 OBDD (有序二元決策圖),以進行概率推理。
由于 HyBit 是作為 Julia 的一個庫實現的,因此程序員還可以使用 Julia 提供的語言結構,例如(有界)循環、元組和函數。
舉個例子,可以在循環體中使用 HyBit 構造配合 Julia 的 for
循環來構建概率模型。
HyBit 作為一個開源項目提供,并附帶了豐富的示例集。
圖11詳細展示了 HyBit 的 API。
DistFix{W, F}
表示一個定點數類型,總位寬為 W,其中小數點后占 F 位。函數
general_gamma
可對指定的廣義伽馬密度函數進行正確的位爆破 ,將其映射到給定的定點數精度 W 和 F 上。通過在廣義伽馬密度上使用
if-then-else
結構,可以實現對混合伽馬密度函數 的正確位爆破。函數
bitblast
用于對任意連續分布進行分段近似 的位爆破,用戶可通過參數 W、F 和 pieces 來指定使用的位數和分段數量。API 還允許用戶選擇用于分段近似的離散分布類型:線性分布 或 指數分布 。
線性分段的參數(斜率)和指數分段的參數()會自動選擇,使得第一個區間與最后一個區間的概率比值 與樸素離散化保持一致。
最后,API 還提供了用于查詢隨機變量的概率分布 、期望值 和 方差 的函數。
接下來的小節將更詳細地描述分段近似方法 ,以及期望值與方差的計算方式 。
4.2 分段近似
盡管混合伽馬分布能夠涵蓋許多常見的自然分布,但仍有一些常見分布不在其中,例如高斯分布 。
目前仍是一個開放問題:高斯分布是否可以被正確地進行位爆破(更不用說編譯成緊湊的 OBDD)。對于這類分布,我們可以使用分段近似 (piece-wise approximation)方法來處理。
設 是定義在區間 [, ) 上的一個任意連續概率分布。
要使用具有 個分段的分布對 進行位爆破,我們通過一個由 個離散概率分布 組成的混合分布來近似 ,這些分布分別對應互不重疊的子區間。
對于每一個分段,我們構造一個經過平移和縮放的位爆破混合伽馬密度函數 ,然后將它們組合成一個混合分布。
請注意,由于每個分段使用了 O() 次硬幣擲,因此具有 個分段的分段分布總共使用 O() 次硬幣擲。
第5節展示了這種策略在實驗上的優勢。
這種使用線性分段 或指數分段 的近似方法,可以很方便地通過 HyBit 中提供的 bitblast
API 實現(見圖11)。
圖12 展示了使用 2、4、8 和 16 個分段 對高斯分布進行位爆破的結果,其中每個分段都是一個位爆破后的指數分布。這為用戶提供了在精度與性能之間的傳統權衡 。
我們將在第5.2節中對此進行更詳細的說明。
4.3 二進制表示的優勢
除了位爆破所提供的簡潔性 之外,二進制表示 在概率推理中還具有重要的額外優勢。
首先,許多混合概率模型涉及對連續隨機變量進行算術運算 。由于我們使用的是定點數的二進制表示,像 +、*、/、< 這樣的算術操作會被編譯為作用于二進制數上的布爾公式(類似于計算機體系結構中的 ALU 電路)。
這種表示方式使得概率推理(特別是我們采用的知識編譯方法 )能夠識別并利用算術運算中存在的結構信息,例如在一個計算過程中各個輸出位之間的條件獨立性 。
最近的研究 [Cao et al. 2023] 描述了這類編譯過程,并通過實驗證明了其在整數運算中的優勢;而 HyBit 則將這些優勢擴展到了定點數 的計算上。
此外,二進制表示 也使期望值和方差的計算更加高效。
對一個定義在 2? 個值上的分布來說,樸素地計算其期望值和方差需要分別計算這 2? 個值的概率。
而使用按位表示(bitwise representation),我們只需計算每個比特位的概率,從而實現了指數級的效率提升 。
需要注意的是,在最壞情況下,對于任意的混合概率程序,獲得其對應的二進制表示的 OBDD 本身可能是關于位數呈指數增長 的。
但對于混合伽馬分布類 ,這種轉換的復雜度是線性的 (見定理10)。
我們在以下兩個定理中形式化了期望值與方差的計算方式,并在附錄中提供了相應的證明。
定理12 :設 D 是定義在區間 [0, 2?) 上的一個離散概率分布,該分布以 位二進制形式表示為 (?, ???, ..., ?),則 D 的期望值可以利用期望的線性性質如下計算:
5 實驗評估
我們評估了在真實概率程序中使用位爆破方法的實用性。為了實現這一目標,我們進行了相關實驗,以探討以下問題:
- Q1
:HyBit 相較于現有的推理算法表現如何?見第5.1節
- Q2
:分段近似的效果如何?見第5.2節
我們將 HyBit 與兩類近似推理算法進行了對比評估。
采樣方法
我們對比了以下代表性算法:
WebPPL 的拒絕采樣(rejection sampling)
使用 Metropolis Hastings 核的 MCMC 采樣
SMC 采樣
Stan 的 HMC 采樣
離散化方法
我們還與以下兩類基于離散化的推理算法進行了比較:
AQUA
GuBPI [Beutner 等 2022; Huang 等 2021]
比較不同概率編程系統的性能是一項具有挑戰性的任務,因為性能直接受到程序結構的影響。
為公平起見,我們在每個系統中編寫了等效的基準程序,并盡最大努力對其進行了優化。
本節及其后續章節中的表格報告了隨機算法在10次運行中的平均絕對誤差 ;對于其他推理算法,則報告單次運行的結果。
所有實驗均采用單線程 方式執行,運行環境為配備 2.4 GHz CPU 和 512 GB RAM 的服務器。
表2報告了 HyBit 與其他近似推理算法在性能評估中的對比結果。
我們選取了所有 Psi [Gehr 等 2016] 曾評估過的混合與連續基準程序 ,并額外補充了一些來自現有研究 [Huang 等 2021] 的相關基準程序。
我們盡最大努力通過解析方法 或使用計算機代數系統 為這些基準程序計算出真實值 (ground truth)。
在本評估中,我們僅包含那些我們可以可靠地獲得真實值 的基準程序。
對于所有基準程序,我們都報告了各算法相對于真實值的絕對誤差 。對于返回非布爾值的基準程序,我們計算了每種方法所得到的期望值的絕對誤差 。
我們報告的是各推理算法在20分鐘超時限制內 所能達到的最小誤差 。
在所有基準測試中:
HyBit 將混合伽馬分布替換為其 正確的位爆破分布 ;
其他分布則被替換為 線性分段近似 版本,即 ?,? 分布。
每個基準程序所使用的位數 (bits)和分段數 (pieces)也在表2中列出。
為了在這些基準程序上運行 Stan,我們使用了 SlicStan [Gorinova 等 2020] 來生成經過離散隨機變量邊緣化處理 后的 Stan 程序。
對于所有 WebPPL 基線方法,我們在20分鐘內對所有采樣算法均使用默認設置 ,并盡可能多地進行采樣。
如表2所示,使用位爆破的 HyBit 在所有基準測試中與現有方法表現相當,甚至在 19 個中的 11 個 基準上表現更優。對于其余 8 個基準,HyBit 的準確率也非常接近。
- AQUA
僅在 4 個基準上表現更好;
- GuBPI
未能獲得良好的準確率,這主要是因為其采用的枚舉式離散化方法在高精度下擴展性較差。
WebPPL 和 Stan (通過 SlicStan 實現自動邊緣化)支持大多數基準程序,但在限定時間內未能達到較高的準確率。這是因為基于采樣的算法具有隨機性,在有限時間內無法從真實的后驗分布中獲取足夠多的樣本。
5.1.2 精確推理算法
表3比較了 HyBit 與一個使用代數方法進行精確推理 的概率編程系統 Psi [Gehr 等 2016]。
我們盡最大努力對基準程序進行了優化翻譯以適配 Psi 的運行環境。Psi 通常會輸出一個符號表達式 ,我們需要將其輸入 Mathematica 進行進一步簡化。
然而,這些代數表達式的計算和簡化并非易事:
Psi 在其中 6 個基準上超時;
Mathematica 也無法簡化其中 4 個基準的結果。
相比之下,HyBit 能夠處理全部 19 個基準程序 ,因為它將計算歸約為對布爾隨機變量的離散推理,并對推理查詢進行近似處理。
5.2 分段近似的效果如何?
我們分析了在使用不同數量的分段對連續分布進行近似時,性能與準確率之間的權衡 。
圖14 展示了在 四個基準程序 上,隨著分段數的增加,運行時間與準確率的變化趨勢 ,并且展示了在不同位寬(bitwidth)下的表現。
隨著線性分段數的增加,運行時間 先下降后上升 。
準確率則如下面四個子圖所示, 隨著分段數的增加而提升 。這是因為當我們增加分段數量時,連續分布被更精確的位爆破分布所替代。
然而,在達到某個“最佳點”之后,準確率的提升是以運行時間增加為代價的。
附錄中提供了更多實驗,進一步證明了相較于基于中心極限定理 的方法,使用分段近似 更具優勢。
6 相關工作
概率編程一直是語義和推理兩個研究方向上的熱門領域 [Dahlqvist 等 2023; Milch 等 2005]。
本節將 HyBit 與相關工作進行對比。
總體而言,HyBit 的關鍵創新在于提出了一種位爆破方法 ,用于對混合概率程序進行簡潔的離散化表示 。這一方法顯著區別于現有工作。
離散化方法
一些早期的方法通過對連續或混合概率程序進行離散化,并通過枚舉所有離散化值 來估計后驗分布 [Huang et al. 2021],但這種方法在很多情況下無法擴展到足夠高的精度。
一種早期的離散化技術也使用了二進制表示 [Claret et al. 2013],但其方法并不是我們所說的“位爆破”(bit blasting),因為它并不簡潔,且最終生成的表示形式仍與離散點的數量成正比。
最近的一項工作則使用離散化來為概率程序的后驗分布提供上下界估計 [Beutner et al. 2022]。
混合概率程序的推理算法
還有一些研究專門針對混合概率程序:
- Leios
[Laurel and Misailovic 2020] 通過對混合程序進行 連續化處理 ,以利用現有連續推理算法的能力。
- HyBit
則采取相反的策略,對混合程序進行 離散化處理 ,這有助于在面對具有復雜離散結構的混合程序時提升推理的可擴展性。
- SPPL
通過將混合程序轉換為特定的表示形式來進行推理 [Saad et al. 2021],但這種表示方式限制了所能支持的混合程序類型。例如,SPPL 不支持對連續隨機變量進行算術運算,而 HyBit 可以做到這一點。
此外,概率邏輯編程語言也被擴展用于支持混合模型,例如通過使用 區間軌跡 (interval traces) [Gutmann et al. 2011]。
一些 PPL 的推理算法通過生成閉式代數表達式 來編碼概率分布,并使用符號計算技術執行精確推理 [Gehr et al. 2016; Hur et al. 2014; Narayanan et al. 2016]。
然而,這些系統在表達能力和可處理程序范圍上都存在一定的限制,如表3所示。
基于路徑的推理算法
PPL 中常見的一類推理算法是操作語義式的 (operational):它們通過使用隨機變量的具體取值來記錄程序的執行路徑。
這類方法包括:
各種采樣算法(如拒絕采樣、MCMC)
變分近似方法 [Bingham et al. 2019; Carpenter et al. 2017; Chaganty et al. 2013; Dillon et al. 2017; Goodman et al. 2008; Hur et al. 2015; Kucukelbir et al. 2015; Mansinghka et al. 2013, 2018; Minka et al. 2014; Pfeffer 2007; Saad and Mansinghka 2016; Tristan et al. 2014; van de Meent et al. 2015; Wingate and Weber 2013; Wood et al. 2014]
像拒絕采樣和 MCMC 這樣的采樣算法雖然通用,但也存在已知的局限性,例如難以處理多模態分布和低概率證據(見第2節)。
更高級的技術如哈密頓蒙特卡洛(HMC)和變分近似方法可以緩解這些問題,但它們要求函數具有連續性和幾乎處處可微性,因此必須通過邊緣化消除所有離散結構。
使用二進制表示
“位爆破”是一種廣泛應用于軟件驗證 領域的技術,常被約束求解器用來基于二進制表示進行算術推理 [Bruttomesso 和 Sharygina 2009]。
最近也有研究在整數上的概率程序推理中采用數字的二進制表示 [Cao et al. 2023],以便利用該表示中的條件獨立性。
HyBit 中的位爆破受到這些技術的啟發,但其目標不同,因而采用了截然不同的技術路線:即開發簡潔且在許多情況下可證明正確 的連續概率分布近似方法。
7 結論與未來工作
在本研究中,我們闡明了為混合概率程序 開發新型推理方法的必要性。
我們提出了位爆破 (bit blasting)方法:通過將混合概率程序進行簡潔的離散化表示 ,然后使用離散推理算法 對其進行分析。
我們定義了一類連續密度函數——混合伽馬密度 (mixed-gamma densities),并證明對于這類分布,位爆破不僅具有簡潔性 ,而且相較于顯式的離散化方法是正確的 (sound),并且在推理上是可證明高效 的。
在此基礎上,我們提出了一種新的概率編程語言 HyBit ,它基于位爆破技術,采用了一種全新的針對混合程序的推理算法。我們通過實驗展示了 HyBit 在性能上優于現有的近似推理算法。在未來的工作中,我們希望進一步擴展能夠被正確位爆破的分布類別 。我們計劃研究如何將 HyBit 擴展以支持分層貝葉斯模型 (hierarchical Bayesian models)。此外,我們還計劃提升其易用性,不再要求用戶為每個概率程序手動指定超參數 。我們也感興趣于探索將 HyBit 與其他推理方法進行集成,以結合它們各自的優勢,從而支持更廣泛的混合概率程序。
原文鏈接:https://github.com/Tractables/Dice.jl/tree/hybit
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.