BENCHMARKING PREDICTIVE CODING NETWORKS
– MADE SIMPLE
預測編碼網絡基準測試——化繁為簡
https://arxiv.org/pdf/2407.01163?
摘要
在本研究中,我們解決了機器學習中預測編碼網絡(Predictive Coding Networks, PCNs)的效率與可擴展性問題。為此,我們開發了一個專注于性能與簡潔性的庫——稱為PCX,并利用它實現了一套大規模的標準基準測試任務,供整個社區用于實驗。由于該領域的大多數研究通常各自提出自己的任務和架構,缺乏相互之間的比較,并且主要集中在小規模任務上,因此一個簡單、快速的開源庫以及一套全面的基準測試將能夠解決上述所有問題。
隨后,我們使用現有的 PCN 算法以及生物可解釋深度學習領域流行方法的改編版本,在這些基準測試上進行了廣泛的實驗。這一切使我們得以:
在更復雜的數據集上測試比文獻中常用的更大規模的架構;
在所有提供的任務和數據集上取得了新的最先進(state-of-the-art)結果;
明確指出了當前 PCN 的局限性,從而為未來的重要研究方向提供了依據。
為了推動整個社區朝著我們認為該領域最重要的開放問題之一——“可擴展性”努力,我們公開了代碼、測試用例和基準測試。
1 引言
1999 年,Rao 和 Ballard(1999)提出了預測編碼(Predictive Coding, PC)的一種形式化表達,用于模擬大腦中的分層信息處理機制。最近人們意識到,這一框架可以用來通過一種生物學上合理的訓練規則來訓練神經網絡(Whittington & Bogacz, 2017)。這引發了不同的研究方向,有的研究探索 PC 網絡的有趣特性(Song et al., 2024;Alonso et al., 2022),有的則提出改進型模型以提升特定任務的表現(Salvatori et al., 2024;Ororbia & Kifer, 2022)。
然而,這些研究往往不進行跨工作的結果對比,而且大多集中在小規模實驗上。因此,該領域忽視了我們認為最關鍵的一個開放問題:可擴展性。
可擴展性問題被忽視的原因有多個方面。首先,這是一個困難的問題,目前尚不清楚為何到目前為止,預測編碼僅能在一定規模下表現得像經典的反向傳播梯度下降一樣好,這個規模大致是使用 CIFAR10 數據集訓練的小型卷積模型(Salvatori et al., 2024)。理解這一點將有助于我們開發穩定學習過程的正則化技術,從而在更復雜任務上取得更好的表現。
其次,由于缺乏專用庫的支持,PC 模型的運行速度極其緩慢:在一個小型卷積網絡上完成一次完整的超參數搜索可能需要數小時。
第三,缺乏統一的框架使得可重復性和持續改進變得困難,因為實現細節或代碼很少被公開。
在本文中,我們首次嘗試解決這些問題,提出了三個方面的貢獻:工具(tool)、基準測試(benchmarking)和分析(analysis)。
工具(Tool):我們發布了一個名為PCX的開源庫,用于加速預測編碼(Predictive Coding)的訓練過程。該庫基于JAX(Bradbury 等,2018)開發,并通過借鑒 PyTorch 風格的熟悉語法,提供了用戶友好的接口,使學習曲線盡可能平緩。我們還提供了詳盡的教程。此外,該庫完全兼容Equinox(Kidger & Garcia, 2021),這是一個面向深度學習的 JAX 擴展庫,確保了代碼的可靠性、可擴展性以及與前沿研究發展的兼容性。
同時,PCX 支持 JAX 的即時編譯(Just-In-Time, JIT),使其在效率上表現優異,既便于開發也利于執行預測編碼網絡。與現有庫相比,PCX 在性能方面具有明顯優勢。
基準測試(Benchmarking):我們提出了一套統一的任務、數據集、評估指標和網絡架構,作為未來測試各類預測編碼變體性能的基礎框架。我們所提出的任務是計算機視覺領域的標準任務:圖像分類與圖像生成。
我們所選用的模型與數據集遵循兩個原則:
允許研究人員從最簡單的任務(如 MNIST 上的前饋網絡)逐步過渡到更復雜的任務;
便于與文獻中相關領域(如平衡傳播和目標傳播,Equilibrium and Target Propagation,Scellier & Bengio, 2017;Bengio, 2014)進行比較。為此,我們選取了一些這些研究論文中常用且一致使用的模型。
在學習算法方面,我們考慮了標準預測編碼(Standard PC)、增量式預測編碼(Incremental PC,Salvatori 等,2024)、結合朗之萬動力學的預測編碼(PC with Langevin dynamics,Oliviers 等,2024),以及 nudged PC(如 Eqprop 文獻中所采用的方法,Scellier & Bengio, 2017;Scellier 等,2024)。值得注意的是,這是首次將 nudging 類算法應用于預測編碼模型中。
分析(Analysis):我們在多個基準測試中取得了預測編碼領域的最先進(State-of-the-Art, SOTA)結果,并首次表明,預測編碼能夠在 CIFAR100 和 Tiny ImageNet 這類更復雜的數據集上表現出色,其性能可與反向傳播相當。在圖像生成任務中,我們進行了彩色圖像數據集上的實驗,超越了以往僅限于 MNIST 和 FashionMNIST 的研究范圍。
我們對實驗結果進行了深入討論,指出了改進的方向,其中最主要的問題是:如何提升預測編碼在非常深層模型中的泛化能力。我們也報告了在這種情況下對信用分配機制(credit assignment)的分析,以更好地理解某些失敗案例背后的原因。
最后,在補充材料中,我們詳細解釋了幫助我們達到 SOTA 結果的超參數、技術與技巧,為本領域的研究人員提供一份實用的操作指南(cookbook)。
2 相關工作(RELATED WORKS)
Rao 與 Ballard 的預測編碼(PC):最相關的工作是那些在深度學習背景下,探索標準預測編碼的不同性質或優化算法的研究。這些研究的公式靈感來源于 Rao 和 Ballard 的原始工作(Rao & Ballard, 1999)。例如,有研究探討了其聯想記憶能力(Salvatori 等,2021;Yoo & Wood, 2022;Tang 等,2023;2024),訓練貝葉斯網絡的能力(Salvatori 等,2022;2023b),以及解釋或改進其優化過程的理論成果(Millidge 等,2022a;b;Alonso 等,2022)。這些研究成果要么提升了模型在不同任務中的表現,要么研究了可能從使用預測編碼網絡(PCNs)中受益的各種特性。
預測編碼的變體(Variations of PC):在文獻中存在多種預測編碼算法的變體。重要的例子包括“偏差競爭”(biased competition)和“除法輸入調制”(divisive input modulation)(Spratling, 2008),以及“神經生成編碼框架”(neural generative coding framework)(Ororbia & Kifer, 2022)。后者已被應用于多個強化學習與控制任務中(Ororbia & Mali, 2023;Ororbia 等,2023),并擁有一個基于 JAX 的開源庫 NGCLearn。
關于不同預測編碼算法如何隨時間演變——從信號處理到神經科學——我們推薦閱讀(Spratling, 2017);而對于專注于機器學習應用的最新綜述,可參見(Salvatori 等,2023a)。
值得一提的是,神經科學領域關于預測編碼的原始文獻也已從 Rao 和 Ballard 的工作發展出一種通用理論,即通過概率和變分推斷建模大腦信息處理的“自由能原理”(free energy principle)(Friston, 2005;Friston & Kiebel, 2009;Friston, 2010)。
受神經科學啟發的深度學習方法(Neuroscience-inspired deep learning):另一類相關工作是將神經科學方法應用于機器學習的方法,例如平衡傳播(equilibrium propagation)(Scellier & Bengio, 2017),這是與預測編碼最相似的方法之一(Laborieux & Zenke, 2022;Millidge 等,2022a)。其他能夠訓練類似規模模型的方法還有目標傳播(target propagation)(Bengio, 2014;Ernoult 等,2022;Millidge 等,2022b)和 SoftHebb(Moraitis 等,2022;Journé 等,2022)。
目標傳播(targetprop)和平衡傳播(eqprop)這兩個社區在其研究論文中經常使用相似的架構來測試他們的方法。因此,在我們的基準測試工作中,部分所提出的架構與它們一致,以促進更直接的對比。
還有一些方法與預測編碼差異較大,例如“僅前向方法”(forward-only methods)(Kohan 等,2023;N?kland, 2016;Hinton, 2022),以及使用一組專門的權重進行誤差反向傳播的方法(Lillicrap 等,2014;Launay 等,2020)。
3 背景與符號說明(BACKGROUND AND NOTATION)
預測編碼網絡(Predictive Coding Networks, PCNs)是一種分層高斯生成模型,包含L 層。每一層都建模一個多變量分布,該分布由前一層的激活所參數化,而前一層又依賴于模型參數θ = θ?, θ?, θ?, ..., θ?和模型狀態h。
令h? ∈ h表示第l層隨機變量向量H?的一個實現值,則我們有如下的似然關系:
(注:原文此處可能緊接著給出具體的數學表達式,比如關于預測誤差或生成模型的形式。由于你提供的段落未包含完整公式,我將根據上下文進行補充翻譯示意。)
即,每一層通過對下一層(或輸入)的預測誤差進行建模來定義其概率分布。這種層級結構使得模型能夠通過自上而下和自下而上的信息交互來進行推斷和學習。
其中,θ?是可學習的權重參數,用于參數化變換f?,而Σ?是一個協方差矩陣,在本工作中將始終設為單位矩陣。例如,如果θ? = (W?, b?)且f?(h???, θ?) = σ?(W?h??? + b?),那么第l?1層的神經元通過一個線性操作(后接非線性映射)連接到第l層的神經元,這與全連接層類似。
直觀上,θ是模型中所有可學習權重的集合,而h = {h?, h?, ..., h?}是依賴于數據點的潛在狀態,包含了對給定觀測值的抽象表示。
訓練:在監督學習設置下,訓練的目標是學習給定輸入-輸出對(x, y)之間的關系。在預測編碼(PC)中,這是通過最大化我們生成模型的聯合似然來實現的,其中潛變量h?和h?分別固定為輸入和標簽:
Pθ(h | h? = x, h? = y) = Pθ(h? = y, ..., h?, h? = x)。
這一目標通過最小化所謂的變分自由能 F(variational free energy)來完成(Friston 等,2007):
我們將由公式(2)描述的第一步稱為推斷階段(inference phase),將第二步稱為學習階段(learning phase)。
在實際操作中,我們并不是僅對單個樣本對(x, y)進行訓練,而是對被劃分為小批量(mini-batches)的數據集依次進行訓練,以更新模型參數。此外,推斷和學習過程都是通過對變分自由能(variational free energy)進行梯度下降來近似實現的。
在推斷階段,首先將h初始化為一個初始值h???,然后對其進行T 次迭代的優化。接著,在學習階段,我們利用新計算出的h值對權重θ執行一次更新。
關于變分自由能相對于h和θ的梯度如下所示:
然后,將一個新的數據批次輸入模型,重復上述過程,直到模型收斂。
如公式(3)所示,每個狀態和每個參數都是使用局部信息進行更新的,因為梯度僅依賴于突觸前誤差??和突觸后誤差????。這正是預測編碼(PC)與反向傳播(BP)相比的主要區別之一:PC 是一種局部學習算法,因此被認為在生物學上更具合理性。
在附錄 A 中,我們對本段中所闡述的概念提供了算法層面的描述,并重點說明了這些數學公式是如何在PCX庫中被轉化為實際代碼的。
評估(Evaluation):給定一個測試樣本x?,我們將h? = x?固定,并再次使用公式(3)中的狀態梯度來計算在h? = x?條件下最可能的潛狀態 h |h?=x? 。我們將這一過程稱為*判別模式 (discriminative mode)。
在實際中,對于判別型網絡,通過這種方式計算出的潛狀態值等價于前向傳播的結果,即對所有層l > 0設置初始值h???? = μ????,因為這對應于自由能F的全局最小值(Frieder & Lukasiewicz, 2022)。
生成模式(Generative Mode):預測編碼網絡(PCNs)也可以用于執行無監督學習任務。給定一個數據點x,其目標是將x的信息壓縮為一種潛在表示,概念上類似于變分自編碼器(VAE)的工作方式(Kingma & Welling, 2013)。
這種壓縮是通過將狀態向量h?固定為該數據點,并運行推斷過程實現的——也就是說,我們通過對h進行梯度下降來最大化Pθ(h | h? = x)。壓縮后的表示即為收斂時(或實踐中經過T步后)得到的h?的值。
如果我們正在進行模型訓練,則隨后會對模型參數進行一次梯度更新,以最小化公式(1)中的變分自由能,這與我們在監督學習中所做的操作一致。
圖 1(a) 展示了預測編碼網絡在判別模式和生成模式下的訓練方式示意圖。
4 實驗與基準測試(EXPERIMENTS AND BENCHMARKS)
我們提出的基準測試是一組標準化的模型、數據集和測試流程,這些在預測編碼研究中曾被廣泛使用,但方式并不統一。為了進行全面評估,我們在多個計算機視覺數據集上測試了具有不同復雜度的模型——包括前饋層、卷積/反卷積層,以及文獻中多種學習算法。
本節分為兩個部分,分別對應判別模式(監督)和生成模式(無監督)的推理任務:前者我們專注于監督分類任務,后者則進行無監督生成任務。圖1展示了這兩種模式的示意圖。
對于每一類實驗,我們都進行了大規模的超參數搜索。附錄 B 和 C 中提供了復現實驗所需的詳細信息,并討論了在此次大規模搜索過程中獲得的經驗教訓。
為了提供全面的評估,我們測試了多個計算機視覺數據集:MNIST(LeCun & Cortes, 2010)、FashionMNIST(Xiao 等, 2017)、CIFAR10/100(Krizhevsky 等, 2009)、CelebA(Liu 等, 2018)以及 Tiny ImageNET(Le & Yang, 2015);同時測試了從簡單到復雜的多種模型結構,以及文獻中多種學習算法。
結果是基于 5 次隨機種子平均得到的,在使用判別模型時見表 1,在使用生成模型時見表 2。需要注意的是,除了最近在 CelebA 上的一次嘗試(Sennesh 等, 2024),這是首次在如 CelebA、CIFAR100 和 Tiny ImageNet 這類數據集上對采用局部消息傳遞的 PCN 進行測試。
算法:我們考慮了文獻中出現的多種學習算法:
- 標準預測編碼(Standard PC),已在背景部分介紹;
- 增量式預測編碼(Incremental PC, iPC)(Salvatori 等, 2024),這是一種近期提出的方法,其特點是權重參數與潛變量在每一步都同步更新;
- 蒙特卡洛預測編碼(Monte Carlo PC, MCPC)(Oliviers 等, 2024),通過在推斷過程中應用未調整的朗之萬動力學(unadjusted Langevin dynamics)實現;
- 正向 nudging(Positive Nudging, PN),其中目標是通過對輸出朝原始 one-hot 標簽方向做小擾動得到;
- 負向 nudging(Negative Nudging, NN),目標是對輸出遠離標簽方向做擾動,并以相反方向更新權重;
- 中心 nudging(Centered Nudging, CN),交替使用正向和負向 nudging 的訓練輪次(Scellier 等, 2024)。
其中,PC、iPC 和 MCPC 將用于生成模式,而 PC、iPC、PN、NN 和 CN 用于判別模式。更詳細的描述請參見圖 1 和補充材料。
4.1 判別模式(DISCRIMINATIVE MODE)
我們通過比較預測編碼(PC)與反向傳播(BP)在圖像分類任務中的表現來測試 PCN 的性能,使用平方誤差損失(SE)和交叉熵損失(CE),并按照 Pinchetti 等(2022)的描述調整能量函數。
在 MNIST 和 FashionMNIST 上,我們使用了包含 3 層隱藏層、每層 128 個神經元的前饋網絡;而在 CIFAR10/100 和 Tiny ImageNET 上,則比較了 ResNet 和 VGG 類模型(He 等, 2016;Simonyan & Zisserman, 2014)。
結果:表 1 顯示,在最復雜的任務上,表現最好的算法是各種 nudging 方法(PN、NN 和 CN)。其中,CN 幾乎總是最優的,這一結果與 Eqprop 文獻中的發現一致(Scellier 等, 2024)。
唯一 nudging 方法表現不佳的情況出現在 Tiny ImageNet 的 VGG7 上,此時 PC-CE 表現優于它們。然而,PC-CE 的結果仍不如 CN 在 VGG5 上的表現。
另一方面,新近提出的 iPC 在小型架構上表現良好,是 MNIST 和 FashionMNIST 上表現最好的方法,但在訓練大型架構時性能下降。
總體而言,深度不超過 7 的模型表現可與反向傳播相當,而更深模型的表現則明顯落后。
關于深度的討論:一個有趣的觀察是,所有 PC 的最佳結果均來自使用 VGG5 架構,性能趨勢為:VGG5 > VGG7 > VGG9 > ResNet,如圖 2 所示。
相比之下,反向傳播訓練的模型則呈現相反趨勢:像 VGG9 這樣的深層模型表現優于 VGG5。在 ResNet18 的實驗中也觀察到了類似的趨勢,PCNs 的測試準確率顯著較低,沒有任何模型接近 VGG5 的性能。而使用反向傳播訓練的 ResNet18 模型則超過了此前所有測試過的 VGG 模型,進一步突出了兩種方法在可擴展性方面的差距。
未來的研究應調查這種現象的原因,因為要擴展到更復雜的數據集,必須使用更深的架構。第 5 節中我們將分析可能的原因,并比較不同算法的實際運行時間(wall-clock time)。
4.2 生成模式(GENERATIVE MODE)
在本節中,我們測試了預測編碼網絡(PCNs)在圖像生成任務中的表現。我們進行了三種不同類型的實驗:
- 從后驗分布中生成圖像
- 通過對學習到的聯合分布進行采樣來生成圖像
- 聯想記憶檢索
在第一種情況下,我們將一張測試圖像y輸入訓練好的模型,運行推斷過程以計算其壓縮表示x?(即收斂時潛變量h?的值),然后通過將h? = x?輸入模型進行一次前向傳播,生成重建圖像? = h?。
我們使用的模型包含三層,在與自編碼器對比時,我們比較的是具有三層編碼器/解碼器結構的模型(總共六層)。對于 MNIST 和 FashionMNIST 數據集,我們使用的是前饋層;而對于 CIFAR10 和 CelebA,則使用卷積/反卷積層。
表 2 和圖 3 中的結果顯示,PC 在更復雜的任務上略優于反向傳播(BP)。在這種情況下,iPC是表現最好的算法,這可能是因為所用模型規模較小,從而提升了穩定性。
接著,我們測試了 PCN 是否能夠學習并從一個復雜的概率分布中采樣。MCPC(蒙特卡洛預測編碼)通過在每個神經元的激活更新中加入高斯噪聲擴展了標準 PC。這一變化使得 PCN 能夠像變分自編碼器(VAE)一樣學習并生成樣本。
這種改進將 PCN 的推斷方式從變分近似轉變為基于朗之萬動力學(Langevin dynamics)的后驗分布蒙特卡洛采樣。通過對所有狀態h?不加限制地執行帶有噪聲的推斷更新,可以從學習到的聯合分布Pθ(h)中生成數據樣本。
圖 4 展示了 MCPC 使用 iris 數據集(Pedregosa 等,2011)學習多模態分布的能力,并展示了其在 MNIST 上的生成樣本。
在與 VAE 的對比中,兩種模型生成的樣本質量相似。MCPC 實現了更低的 FID 分數(MCPC:2.53±0.17 vs. VAE:4.19±0.38),而 VAE 則實現了更高的 inception score(VAE:7.91±0.03 vs. MCPC:7.13±0.10)。
在聯想記憶(Associative Memory, AM)實驗中,我們測試了模型在接受一個不完整或被破壞的圖像輸入后,是否能夠成功重建原始訓練圖像的能力,這一設置借鑒了之前的研究工作(Salvatori 等,2021)。
圖 5 展示了一個具有 2 個隱藏層、每層 512 個神經元的預測編碼網絡(PCN)在面對噪聲或遮擋損壞圖像時所獲得的結果。
在表 3 中,我們研究了隨著隱藏層數量增加時的記憶容量。當 MSE(均方誤差)不超過 0.005 時,重建圖像與原始圖像之間在視覺上沒有明顯差異。
為了評估效率,我們在TinyImageNet 的 500 張樣本上訓練了一個具有 5 個隱藏層、每層 512 個神經元的 PCN,訓練時批量大小為 50,每次訓練包含 50 次推斷迭代。在Nvidia V100 GPU上,每個 epoch 的訓練耗時為0.40 ± 0.005 秒。
討論:結果表明,預測編碼(PC)不僅可以執行生成任務,還能使用僅解碼器結構完成聯想記憶任務。通過推斷過程,PCN 能夠在其潛狀態中編碼復雜的概率分布,并可用于執行多種不同的任務,正如我們所展示的那樣。
雖然這突出了 PCN 在生成模式下的靈活性,但其代價是由于需要進行多次推斷步驟,因此計算成本更高。
5 分析與度量指標(ANALYSIS AND METRICS)
在本節中,我們報告了一些我們認為對理解當前使用預測編碼(PC)訓練網絡的狀態和挑戰至關重要的度量指標,并在適當的情況下將其與使用梯度下降和反向傳播(backprop)訓練的標準模型進行比較。
我們首先進行的一項研究分析了網絡狀態h的初始化方式如何影響模型性能。在文獻中,這些狀態通常被初始化為零、通過高斯先驗隨機初始化(Whittington & Bogacz, 2017),或通過一次前向傳播來初始化。
最后一種技術是機器學習論文中的首選方法,因為它將每一層(特別是輸出層以外的內部層)的誤差??=L = 0設為零。這使得預測誤差僅集中在輸出層,從而等同于平方誤差(SE)。
為了比較這三種初始化方法,我們在 FashionMNIST 數據集上訓練了一個 3 層的前饋模型。結果如圖 6(a) 所示,表明前向初始化確實是更優的方法,盡管隨著推斷迭代次數T增加,不同初始化方法之間的性能差距會縮小。
能量傳播(Energy propagation):將模型的總誤差集中在最后一層,會使推斷過程難以將這種“能量”再傳播回前面的層。正如圖 6(b) 所示,我們觀察到即使經過多次推斷步驟后,最后一層的能量仍比輸入層高出幾個數量級。
一個快速在整個網絡中傳播能量的簡單方法是:在更新狀態時使用等于1.0的學習率,這樣不會產生任何能量不平衡,如圖 6(d) 所示。
然而,無論是圖 6(b) 的結果,還是我們在第 4 節中進行的大規模實驗分析都表明:最佳性能始終是在狀態學習率γ明顯小于1.0的情況下實現的。這引發了一個問題:更好的初始化或優化技術是否可以帶來更平衡的能量分布,從而獲得更有效的權重更新?
為了更好地理解能量傳播與模型性能之間的關系,我們分析了測試準確率以及相鄰層之間能量比值隨狀態學習率γ的變化情況。結果如圖 6(c,d) 所示,顯示較小的學習率雖然帶來了更好的性能,但也導致了各層之間巨大的能量不平衡。
一方面,當γ = 1時,第一個隱藏層的能量與最后一層相近;而當γ = 0.01時,前者比后者低約六個數量級。
另一方面,使用γ = 1訓練的模型性能明顯更差。這些結果表明,當前的訓練設置傾向于在不同層之間產生巨大的能量不平衡,這一問題會導致當模型深度增加時,梯度呈指數級減小。
我們在附錄 D 中提供了其他數據集上的實現細節和實驗結果。
訓練穩定性(Training stability):我們觀察到權重優化器與隱藏層維度對模型性能的影響之間存在關聯。為了更深入地研究這一點,我們使用不同隱藏層維度、狀態學習率γ以及優化器訓練了前饋預測編碼網絡(PCNs),并將結果展示在圖 7 中。
結果顯示,當使用Adam優化器時,隱藏層的寬度會顯著影響訓練過程穩定所需的γ學習率取值范圍。有趣的是,這種現象在使用SGD優化器時并未出現,也不出現在使用反向傳播訓練的標準網絡中。
這種與反向傳播(BP)之間的行為差異是出乎意料的,它表明預測編碼網絡(PCNs)需要更好的優化策略。盡管在我們的實驗中,AdamW仍然是最優選擇,但它可能成為更大規模架構發展的瓶頸。
6 庫、資源與實現細節(LIBRARY, RESOURCES AND IMPLEMENTATION DETAILS)
在本節中,我們將介紹我們用于實驗并開源發布的工具PCX。
PCX是基于JAX開發的,專注于性能和通用性,并建立在以下三個核心理念之上:兼容性(Compatibility)、模塊化(Modularity) 和高效性(Efficiency)。
兼容性(Compatibility)
PCX 借鑒了Equinox(Kidger & Garcia, 2021)的設計哲學,即將模型視為PyTrees(JAX 中的數據結構)。因此,它完全兼容 JAX 的函數式編程范式,也能夠無縫對接許多為 JAX 構建的庫和工具,例如diffrax(Kidger, 2021)和optax(DeepMind 等, 2020)。
這意味著將最新的深度學習研究成果快速集成到 PCX 中將是十分便捷的。同時,PCX 也提供了命令式的面向對象接口,允許研究人員以類似 PyTorch 的方式構建預測編碼網絡(PCNs)。
模塊化(Modularity)
借助面向對象的抽象設計,我們構建了可組合的基本模塊來創建 PCN,主要包括:
- 模塊類
(Module class):表示抽象的能量模型;
- 向量化節點
(Vectorized nodes):存儲狀態h;
- 優化器
(Optimizers):執行預測編碼網絡中的推斷與學習過程;
- 標準層
(Standard Layers):如全連接層、卷積層等。
我們在本文中展示的每一個基準測試任務都可以通過靈活組合和配置這些模塊來實現。
高效性(Efficiency)
PCX 廣泛依賴于Just-In-Time(JIT)即時編譯技術。根據我們的初步基準測試,在對 PCN 進行 JIT 編譯后,其運行速度最高可提升50 倍。
我們認為這種顯著差異來源于預測編碼本身的特性:相比反向傳播,預測編碼需要進行多個較小的操作(即每一層都要執行 T 次推斷步驟),因此更容易受到“急切執行模式”下函數調用開銷的影響。
PCX 提供了一個統一的接口,用于在多種任務上測試不同版本的預測編碼算法。我們的模塊化代碼庫未來可以輕松擴展,以支持新的預測編碼變體,正如我們已經展示了與現有各種變體和訓練技術的完整兼容性一樣。
這與諸如 (Song, 2024) 和 (Ororbia & Kifer, 2022) 中使用的單一體系或低級方法形成了鮮明對比。
6.1 計算資源與限制(COMPUTATIONAL RESOURCES AND LIMITATIONS)
我們將 PCX 的實現與另一個廣泛使用的開源庫(Song, 2024)進行了實際運行時間(wall-clock time)的比較,該庫被用于多篇預測編碼相關研究(Song 等, 2024;Salvatori 等, 2021;2022;Tang 等, 2023)。此外,我們也將其與使用相同架構但通過反向傳播訓練的網絡進行了對比(同樣使用 PCX 實現,以確保公平比較)。
表 4 報告了在 A100 GPU 上平均 5 次實驗所測得的每個 epoch 的運行時間。我們甚至優于其他方法,如 Eqprop:在 CIFAR100 上使用相同的架構時,作者報告一個 epoch 需要約110 秒,而我們在相同硬件上僅需約 5.5 秒(Scellier 等, 2024)。
不過,這并不是一種嚴格的“同類比較”,因為原作者更關注的是模擬電路中的仿真效果,而非最大化 GPU 利用率。
局限性(Limitations)
PCX 的效率還可以通過完全并行化所有操作進一步提升。目前,JIT 尚無法自動并行化各層之間的計算;這一問題理論上可以通過 JAX 的vmap原語解決,但在當前情況下只有當所有層具有相同維度時才可行,這在現實中并不實用。
為了測試不同模型超參數如何影響訓練速度,我們采用了一個前饋模型,并多次訓練,每次按倍數增加某一特定超參數。結果如圖 8 所示,顯示層數 L和推斷步數 T是兩個最顯著影響訓練時間的參數。
理想情況下,應該只有 T 影響訓練時間,因為推斷是一個本質上順序化的流程,無法并行化;但實際情況并非如此,訓練時間隨層數線性增長。
詳細信息請參見附錄 G。
7 討論(DISCUSSION)
本工作的主要貢獻是推出并開源發布了一個名為PCX的庫,該庫可用于使用預測編碼網絡(PCNs)執行深度學習任務。它的高效性依賴于JAX 的 Just-In-Time(JIT)即時編譯技術,以及經過精心設計、能夠充分利用 JIT 性能的基礎組件(primitives)。
我們這個庫的第二個優勢在于其直觀的設計,特別適合那些已經熟悉如 PyTorch 等其他深度學習框架的用戶。結合我們發布的大量教程,新用戶將可以輕松上手并使用 PC 來訓練神經網絡。
我們隨后使用PCX對文獻中不同模型和訓練算法進行了廣泛的對比研究,測試了大量參數組合與激活函數,得出了具有代表性的結果。
在實驗結果方面,我們展示了:在使用小型/中型架構(如 VGG7)的前提下,預測編碼網絡的表現可與使用反向傳播(BP)訓練的標準深度學習模型相媲美。然而,一旦放寬這一限制,當模型規模增大時,預測編碼的表現就無法跟上 BP 的擴展能力。
在補充材料中,我們還提供了嚴謹的研究,詳細說明了能量在 PCN 中隨時間流動的方式、訓練穩定性分析,并展示了 PCN 如何對分布外(out-of-distribution)數據進行分類,以及通過使用跳躍連接(skip connections)來訓練極深網絡的可能解決方案。
附錄(APPENDIX)
在本部分中,我們提供了有關實驗是如何進行以及結果是如何獲得的詳細信息。我們采用更具描述性的敘述方式來傳達基本概念,而所有用于復現實驗的具體細節則包含在所提供的代碼中,以及接下來的各個小節中。
每個小節都將鏈接到與所述實驗相對應的確切目錄。
A PCX — 簡要介紹(A BRIEF INTRODUCTION)
在本節中,我們將通過描述在預測編碼框架下訓練和評估一個前饋分類器所需的主要構建模塊,來說明PCX的核心理念。如需更詳細和完整的解釋,請參考庫中examples 文件夾下的教程筆記本(tutorial notebooks)。
在第 3 節中,我們將預測編碼網絡(PCNs)定義為具有參數θ = {θ?, ..., θ?}和狀態h = {h?, ..., h?}的模型。在PCX中,我們將一個模型劃分為兩個主要組成部分:
- 層(layers):即傳統的深度學習變換,例如 “Linear”(全連接層)或 “Conv2D”(卷積層);
- 節點(vodes,vectorized nodes) : 即向量化節點,用來存儲表示狀態
h?的神經元數組。
一個預測編碼網絡(PCN)的定義如下:
在__call__
方法中,我們將輸入x通過網絡進行前向傳播。請注意,每次我們調用一個vode(向量化節點)時,實際上是在其中存儲了激活值u?(這樣我們之后可以計算與該 vode 相關的能量??2),并返回其狀態h?(即x = vode(u)
對應于vode.set("u", u); x = vode.get("h")
)。
在訓練過程中,標簽y被提供給模型,并通過覆蓋最后一層 vode 的狀態h???來將其固定為標簽值。
需要注意的是,在訓練和評估階段,第一個 vode 的狀態都會被固定為輸入x,因此我們并沒有顯式定義它(即我們避免計算Pθ?(h?),因為它是常數),而是直接將輸入x傳遞給第一層的變換。
類pxc.EnergyModule
提供了一個.energy()
函數,用于根據公式(1)計算變分自由能F。
我們可以通過調用pxf.value_and_grad
(這是對 JAX 同名函數的封裝)來按照公式(3)計算狀態和參數的梯度。
在分別定義了兩個優化器optim_w
(用于參數)和optim_h
(用于狀態)之后,我們可以如下定義針對一對數據(x, y)
的訓練過程:
import pcx.utils as pxu
import pcx.functional as pxf
關于上述代碼的一些說明:
JAX(Bradbury 等,2018)是一個函數式編程庫,而PCX 并非完全函數式。PCX 中的模塊是PyTrees(JAX 中的標準數據結構),它借鑒了另一個流行的 JAX 庫Equinox(Kidger & Garcia, 2021)的設計哲學,并且 PCX 的模塊與 Equinox 完全兼容。
然而,模塊的狀態由 PCX 自動管理,使得每個參數變換都會被自動追蹤。用戶可以通過將參數作為關鍵字參數傳入來選擇啟用這種行為(如上面的例子所示)。而位置參數則會被 PCX 忽略,此時用戶需要像在 JAX 或 Equinox 中那樣自行管理其狀態。
pxf.value_and_grad
允許指定一個Mask 對象,用于標識要對哪些參數執行梯度計算。在上面的例子中,我們首先計算自由能F相對于狀態(VodeParam
)的梯度,然后是相對于模型權重(LayerParam
)的梯度。在
train
函數中,我們使用pxu.step
將模型狀態設置為pxc.STATUS.INIT
,以執行狀態初始化。在 PCX 中,默認使用的是前向初始化方法,但也可以很方便地指定其他初始化方式。此外,pxu.step
還用于清除 PCN 的緩存,這些緩存用于存儲中間值(如激活值u?)。庫中的實際示例是針對數據的 mini-batch 進行的,因此在實際實驗中,上述所有操作都通過vmap向量化處理。
對于判別模式下的評估函數,我們只需對 PCN 執行一次前向傳播,該過程會將所有層的誤差?? = 0。
B 判別模式實驗(DISCRIMINATIVE EXPERIMENTS)
模型:我們對三種模型進行了實驗:MLP、VGG-5 和 VGG-7。這些模型的詳細架構見表 5。
對于每種模型,我們使用以下不同的算法進行了實驗:
標準預測編碼結合交叉熵損失(PC-CE)/均方誤差損失(PC-SE):已在背景部分介紹。
正向 nudging 預測編碼(PC-PN):
與標準的均方誤差損失預測編碼(PC-SE)不同,在標準方法中輸出被固定為目標值;而在帶有 nudging 的預測編碼中,我們是將輸出“推向”目標。具體來說,我們將最后一層的表示h?固定為:
4 中心 nudging 預測編碼(PC-CN):中心 nudging(Scellier 等,2024)常用于平衡傳播(equilibrium propagation)中,以改善并穩定正向和負向 nudging 的性能。它通過取這兩種方法產生的梯度的平均值得出。在這里,我們通過在不同訓練輪次中隨機交替使用正向或負向 nudging 來近似這一行為。這樣,訓練模型可以同時受益于兩種方法,而不會帶來額外的計算成本。
5 增量式預測編碼(iPC):這是一種簡單且近期提出的方法,其特點是權重參數與潛變量在每一步都同步更新(Salvatori 等,2024)。
6
標準反向傳播結合交叉熵損失 (BP-CE)/均方誤差損失 (BP-SE):這是神經網絡中最常用的信用分配方式。模型通過鏈式法則計算損失函數對網絡權重的梯度來進行訓練。
實驗:
MLP 模型的基準結果是在MNIST和Fashion-MNIST數據集上獲得的;
VGG-5 模型的結果是在CIFAR-10、CIFAR-100和Tiny ImageNet上獲得的;
VGG-7 模型的結果則是在CIFAR-100和Tiny ImageNet上獲得的。
數據按照表 6 中的方式進行了歸一化處理。
對于CIFAR-10、CIFAR-100和Tiny ImageNet訓練集的數據增強,我們以 50% 的概率應用了隨機水平翻轉(random horizontal flipping)。此外,我們還采用了針對每個數據集設置不同的隨機裁剪(random cropping)操作:
對于CIFAR-10和CIFAR-100,圖像在每側填充 4 個像素后被隨機裁剪為32×32分辨率;
對于Tiny ImageNet,圖像則直接被隨機裁剪為56×56分辨率,不進行填充。
在Tiny ImageNet的測試集中,我們使用中心裁剪(center cropping)提取56×56分辨率的圖像,同樣不進行填充,因為 Tiny ImageNet 原始分辨率為64×64。
模型的超參數是根據表 7 所示的搜索空間確定的。表 1 中展示的結果是使用 5 個隨機種子并在最優超參數下訓練得到的。
關于優化器和學習率調度器:
我們使用帶有動量的小批量梯度下降(SGD)作為狀態h的優化器;
使用AdamW(Loshchilov & Hutter, 2017)結合權重衰減(weight decay)作為參數θ的優化器;
同時,我們對θ的學習率采用了一個預熱余弦退火調度器(warmup-cosine-annealing scheduler),且不進行重啟(without restart)。
結果:本研究中呈現的所有結果均使用前向初始化(forward initialization)方法獲得,該技術通過對一個與輸入數據形狀相同的零張量(zero tensor)進行一次前向傳播來初始化模型參數。
此外,在我們的實驗中,我們對T的取值范圍進行了限制,以確保在訓練時間方面與反向傳播(BP)進行公平比較。
更大的T值意味著對狀態h進行更多輪次的優化,這可能會提升模型性能,但同時也會增加計算成本和延長訓練時間。為了保持與 BP 的可比性,我們將T的搜索范圍限制在使得訓練時間與基于 BP 的訓練相當的區間內。
動量(momentum)起到了顯著作用。在圖 9 中,我們展示了使用不同動量值訓練的 VGG-7 模型在 CIFAR-100 數據集上的準確率,其中包括不使用 nudging(圖 9a)和使用 nudging(圖 9b)的情況。
從圖 9 可以明顯看出,選擇合適的動量值可以顯著提升模型的準確率。通過對比圖 9a 和圖 9b,我們可以觀察到不同的訓練算法具有不同的最佳動量值。
使用nudging進行訓練時,其最佳動量值通常高于不使用 nudging 的情況;而使用負向 nudging(negative nudging)的最佳動量值又高于使用正向 nudging(positive nudging)的情況。
這些最佳動量值之間的差異突顯了根據所采用的具體訓練算法和 nudging 方法,仔細調整動量超參數的重要性。
作為參考,各種任務和模型下的最佳動量參數值可在PCX 庫的 example/discriminative_experiments 文件夾中找到。
激活函數在提升模型準確率方面也起著關鍵作用。
對于使用交叉熵損失(Cross-Entropy Loss)的模型,“HardTanh”激活函數是更好的選擇;
對于使用均方誤差損失(Mean Squared Error Loss)且不使用 nudging的模型,“LeakyReLU”激活函數通常表現更好;
在使用正向 nudging的情況下,最佳激活函數會因模型架構而異;
而在使用負向 nudging時,“GeLU”激活函數是最合適的選擇。
Nudging 提升了模型性能。圖 10 展示了在有或沒有 nudging 的情況下,狀態h的學習率與模型準確率之間的關系。
從圖中可以看出,在不使用 nudging的情況下(紅色點),模型在較低的學習率下表現更好;
而在使用 nudging的情況下(紫色和藍色點),無論是正向還是負向 nudging,模型都能在更高的學習率下獲得更好的準確率。
此外,圖 9b 還展示了動量與準確率之間的關系。我們可以看到,在應用 nudging 后,模型在更高的動量值下也能取得更好的結果。
我們認為這正是 nudging 能夠提升性能的原因:能夠在不犧牲準確率的前提下使用更高的學習率和動量值,這是 nudging 的一大顯著優勢,因為它可以帶來更快的收斂速度和更好的泛化性能。
C 生成實驗(GENERATIVE EXPERIMENTS) C.1 自編碼器(AUTOENCODER)
自編碼器是一種網絡結構,它學習如何將高維輸入盡可能準確地壓縮到一個低得多的維度空間中,這個空間被稱為瓶頸維度(bottleneck dimension)或隱藏維度(hidden dimension)。
因此,基于反向傳播的自編碼器由兩個部分組成:
- 編碼器(Encoder):將輸入從原始高維空間壓縮到瓶頸維度;
- 解碼器(Decoder):從瓶頸維度重建原始輸入。
訓練過程中使用原始輸入與重建輸入之間的均方誤差(MSE)作為損失函數,以無監督的方式訓練整個自編碼器網絡。
預測編碼(Predictive Coding, PC)則省去了傳統自編碼器中編碼器的部分。具體來說,僅使用自編碼器的解碼器部分,其中一層預測編碼層充當瓶頸維度,并作為解碼器的輸入。此外,在解碼器的每一層之后都會插入一個 PC 層。
基于預測編碼的自編碼器工作方式如下:
最后一層 PC 層的能量函數在創建時被設定為 MSE(均方誤差)。在 PCX 中,平方誤差是默認的能量函數。該平方誤差會在輸入的所有維度上進行求和,并在批次上取平均,從而近似等于標準 MSE 損失,最多相差一個常數倍數。
最后一層 PC 層(第 L 層)的當前狀態 h? 被固定為原始輸入數據,這意味著在推斷過程中h? 不會被更新。
由于最后一層的能量現在表示的是預測圖像μ?與作為h?存儲的原始輸入之間的 MSE 損失,因此在推斷過程中,所有非最后一層的 PC 層的狀態h?都將被更新,包括代表瓶頸維度的那一層,以最小化這個 MSE 損失。
一旦完成推斷過程,瓶頸維度對應的PC層的狀態將收斂到原始輸入的一個壓縮表示。
模型:蒙特卡洛預測編碼(Monte Carlo Predictive Coding, MCPC)是一種可用于生成式學習的預測編碼變體。
MCPC 與標準預測編碼(PC)的區別在于其帶有噪聲的神經元動態機制。不同于標準 PC 中神經元活動收斂到自由能的一個極值點,MCPC 的神經元活動執行的是帶有噪聲的梯度下降,用于進行蒙特卡洛采樣。
當提供輸入時,MCPC 的噪聲神經元活動會對給定感知輸入下生成模型的后驗分布進行采樣;
而當沒有輸入時,神經元活動則對模型參數所編碼的生成模型本身進行采樣。
具體來說,MCPC 的神經元動態基于如下的朗之萬動力學(Langevin dynamics):
一個MCPC 模型是按照蒙特卡洛期望最大化(Monte Carlo Expectation Maximization, MCEM)方案進行訓練的,該過程迭代執行以下兩個步驟:
- MCPC 的神經元活動對給定數據下的模型后驗分布進行采樣
- 模型參數根據這些后驗樣本進行更新,以提高模型在這些樣本下的對數似然函數值
在實際操作中,我們僅運行有限步數的 MCPC 推斷過程,之后使用一個后驗樣本來更新模型參數。這一方式類似于變分自編碼器(VAE)中更新模型參數的方式。
訓練完成后,通過不對任何神經元施加約束(即不固定輸入層),并記錄輸入神經元(即訓練時被固定為數據的那些神經元)的活動來生成已訓練模型的樣本。
在有限次數的狀態更新步驟之后記錄這些活動。此過程對每個數據樣本重復執行。
實現細節
PCX 中的 MCPC 實現使用了一個帶有噪聲的 SGD 優化器來更新狀態h。與使用 SGD 或 Adam 優化器的標準 PC 不同,MCPC 使用的優化器將向梯度中加入噪聲并與 SGD 結合使用。
為了使噪聲的方差能夠適當隨著學習率和動量進行縮放,其形式需精心設計,如公式(4 - 6)所示。
實驗設置
所有MCPC 實驗均使用具有平方誤差(SE)損失的前饋模型。
狀態層h?的 SE 損失還通過一個方差參數σ2??進行了縮放。引入這個額外參數是為了防止高斯層h?的方差遠大于數據本身的方差,從而避免影響學習效果。
在無條件學習與生成任務中,h? 層在訓練和生成過程中都不被固定;
而在 MNIST 上的有條件學習任務中,h? 層在訓練和生成時都被固定為標簽。
我們在 Iris 數據集上訓練了一個維度為[2 x 64 x 2]
的模型,激活函數為tanh
,使用默認參數值如下:
狀態學習率 γ = 0.01
狀態動量 = 0.9
噪聲狀態方差 σ2mcpc = 1
參數學習率 lrθ,參數衰減 = 0.0001
使用 Adam 作為參數優化器
層方差 σ2?? = 0.01
批次大小 = 150
在訓練過程中使用500 步狀態更新,在生成階段使用10000 步狀態更新。
MNIST 上的無條件學習任務
我們在 MNIST 上訓練了維度為[30 x 256 x 256 x 256 x 784]
的模型。
MCPC 和 VAE 的模型超參數是通過表 10 所示的搜索空間確定的,分別用于優化FID 分數和inception score。
具體最優參數值請參考代碼倉庫。
訓練過程中使用1000 步狀態更新,生成階段使用10000 步狀態更新。
MNIST 上的有條件學習任務
我們在 MNIST 上訓練了維度為[2 x 256 x 256 x 256 x 784]
的模型。
本任務中使用的標簽被固定到h? 層,用于表示圖像對應的是偶數還是奇數數字。
模型超參數通過表 10 所示的搜索空間確定。
訓練過程中使用1000 步狀態更新,生成階段使用10000 步狀態更新。
實驗結果 圖 12 顯示了在最大化 inception score 的超參數設置下,已訓練模型所生成的樣本。
C.3 聯想記憶(Associative Memories)
本節描述了聯想記憶任務的實驗設置。
模型(Model)一個生成式PCN(generative PCN)首先在從Tiny ImageNet數據集中采樣的n張圖像上進行訓練,直到其參數收斂。隨后,將訓練圖像的一個受損版本呈現給模型的感覺層(hL),并對所有層(包括感覺層)運行推理過程(?hl),直到收斂。需要注意的是,在掩碼實驗中,圖像的上半部分在整個推理過程中保持不變。
直觀上,假設模型在訓練期間已經通過固定每個訓練樣本的感覺層值,最小化了其自由能,那么它就形成了以這些訓練樣例定義的吸引子(attractors),因此會傾向于將受損圖像“優化”回這些能量吸引子中。
實驗(Experiments)在這里,基準結果是使用 Tiny ImageNet 數據集獲得的,圖像被加入了高斯噪聲(標準差為0.2),或對圖像下半部分進行了遮擋(掩碼處理)(示例見圖5)。我們通過改變模型大小和需要記憶的訓練樣本數量,來研究模型的記憶容量。
具體來說,我們使用了一個架構為 [512, d, d, 12288] 的生成式 PCN,其中 d = [512, 1024, 2048](12288 是展平后的 Tiny ImageNet 圖像維度),并設置了 n = [50, 100, 250]。對于每一個 d 和 n 的組合,我們在以下超參數范圍內進行了搜索:
參數學習率 lrθ ∈ {1 × 10?? + k · 5 × 10?? | k ∈ Z, 0 ≤ k ≤ 18}
狀態學習率 γ ∈ {0.1 + k · 0.05 | k ∈ Z, 0 ≤ k ≤ 18}
訓練階段的推理步數 Ttrain ∈ [20, 50, 100]
回憶階段的推理步數 Trecall ∈ [50000, 100000]
我們固定激活函數為 Tanh,訓練輪數為 500 輪,批量大小為 50。表3中的結果是在5個隨機種子下使用最優搜索得到的超參數得出的。
D 能量與穩定性(Energy and Stability)
本節描述了第??節的實驗設置,提供了在其他數據集上的復現結果以及消融實驗(ablations)。
D.1 能量傳播(Energy Propagation)
我們在多個數據集上測試一系列模型網格(grid of models),以考察模型中的能量傳播情況。我們測試了 FashionMNIST、Two Moons 和 Two Circles 數據集。其中,Two Circles 數據集尤其有趣,因為不良的能量分布會直觀地導致線性歸納偏置(inductive bias)(我們主要學習的是一個單層網絡)。這種線性歸納偏置對 Two Circles 的性能影響最大(線性模型準確率約為 50%),相比之下,FashionMNIST(約 83%)和 Two Moons(約 86%)受影響較小。
實驗設置(Experimental Setup)我們訓練了一個包含兩個隱藏層的前饋 PCN 模型網格。我們在三個數據集上進行了訓練:主文中報告的 FashionMNIST,以及額外的 Two Moons 和 Two Circles 數據集。對于所有模型,我們訓練 8 個 epoch,并使用 T = 8 次推理步數。狀態(states)通過 SGD 進行優化,并采用前向初始化(forward initialization)。
模型網格覆蓋以下參數組合:
權重學習率 lrθ ∈ {1 × 10??, 1 × 10??, ..., 1}
狀態學習率 γ ∈ {1×10?3, 3×10?3, 1×10?2, 3×10?2, 1×10?1, 3×10?1, 1}
激活函數 f ∈ {LeakyReLU, HardTanh}(前者無界,后者有界)
優化器:AdamW 或帶動量的 SGD,動量 m ∈ {0.0, 0.5, 0.9, 0.95}
隱藏層寬度:
FashionMNIST:{512, 1024, 2048, 4096}
Two Moons 和 Two Circles:{128, 256, 512, 1024}
我們為 FashionMNIST 的所有實驗重復了 3 個隨機種子,其他數據集則重復了 10 個隨機種子。
結果(Results) 主文中的圖6(左)展示了在網格中表現最佳的模型在訓練結束時最后一個 batch 中各層的平均能量。圖6(中左)比較了動量為 0.9 的 SGD 與 AdamW 的表現。該圖基于激活函數“HardTanh”和寬度為 1024 的模型繪制。我們在圖13中展示了其他激活函數和寬度組合的結果。
我們觀察到,在所有條件下,SGD 通常更偏好小到中等大小的狀態學習率,而 AdamW 則更傾向于更小的狀態學習率。由于各層之間的能量分布不均,AdamW 尤其可能難以擴展到更深的網絡結構。
此外,我們還觀察到 AdamW 的性能方差更大,尤其是在較寬的層中,這一點我們在第 ?? 節的“訓練不穩定(Training Instability)”段落以及下文中進一步討論。
圖6(右)是基于所有使用 AdamW 訓練的模型繪制的。許多具有高狀態學習率的模型發生了發散,我們只繪制了準確率 > 0.5 的模型。
接下來我們展示在 Two Moons 和 Two Circles 數據集上的實驗結果。圖14b、14a 和 14c 展示了 Two Moons 對應于圖6的結果,圖15b、15a 和 15c 展示了 Two Circles 的結果。這些結果與 FashionMNIST 上的結果非常相似:即使在 T 次推理步驟之后,能量仍集中在最后一層。
然而,在 Two Circles 的例子中,我們實際上觀察到了前面層的訓練效果:雖然由于誤差傳播,能量最初有所上升(但數量級仍遠低于后面的層),隨后能量又下降了。能量比例一致表明,對于那些表現良好的狀態學習率 γ,能量傳播仍然較差。
正如預測的那樣,Two Circles 數據集的結果方差顯著更大,尤其是當狀態學習率較小時。
D.2 訓練穩定性(Training Stability)
我們測試了一系列 PCN 模型,以分析模型寬度、狀態學習率和權重優化器之間的相互作用。
實驗設置(Experimental Setup)我們在 FashionMNIST(如上所述)和 Two Moons 數據集上訓練模型。我們訓練了具有兩個隱藏層的前饋 PCN,并使用“LeakyReLU”作為激活函數,在一組參數網格上進行實驗。所有模型均訓練 8 個 epoch。隱藏層的寬度為 {32, 64, ..., 4096}。
狀態變量通過 SGD 進行訓練,推理步數 T = 8,狀態學習率 γ ∈ {1×10??, 3×10??, ..., 0.3}。權重則通過 SGD 或 Adam 優化器更新:在 FashionMNIST 上使用學習率為 0.01,在 Two Moons 上使用學習率為 0.03。兩種優化器對權重都使用了 0.9 的動量。
我們還訓練了使用相同超參數的反向傳播(BP)基線模型。對于 FashionMNIST,每個實驗重復 3 次隨機初始化;對于 Two Moons,則重復 10 次。
結果(Results) 我們將圖7(對應于 FashionMNIST)的結果復制到 Two Moons 數據集上,結果見圖16。我們觀察到 Two Moons 上的現象與 FashionMNIST 類似:Adam 優化下,優化穩定性強烈依賴于隱藏層的寬度,而 SGD 在任一數據集上都沒有表現出這種效應。
這進一步支持了我們在第 ?? 節中的結論:雖然 Adam 是更優的優化器,但其與網絡寬度和狀態學習率之間的交互效應(width × γ)可能會阻礙使用 Adam 的 PCN 的擴展能力。因此,PCN 的優化方法仍需研究社區的進一步關注。
消融實驗(Ablation)我們還在 FashionMNIST 上進行了一個消融實驗。在之前的實驗中,隱藏層寬度發生了變化,這不僅引入了隱藏層絕對大?。瓷窠浽獢盗浚┑淖兓哺淖兞司W絡中隱藏層的相對大小(因為輸入層和輸出層的大小在所有實驗中保持不變)。因此,我們設計了一個新的實驗:在 FashionMNIST 上,我們增大圖像尺寸,并將標簽向量用 0 填充,使得所有層的寬度相等。
其他所有實驗變量保持不變。結果如圖17所示,并延續了圖7和圖16中觀察到的趨勢:我們發現優化器與網絡寬度之間存在如前所述的交互作用。因此,僅考慮層寬度的相對變化不足以解釋該問題,我們可以得出結論:使用 AdamW 時,層的絕對大小在優化穩定性中也起到了重要作用。
ResNets 在此部分,我們結合 ResNets18 的實驗結果,討論能量傳播的相關發現。我們已經表明,節點的較低學習率會損害能量傳播,并且當隱藏維度較大時,AdamW 優化器表現不佳。
為此,我們使用 SGD 和較大的節點學習率訓練了 ResNets18,并將其性能與主文中報告的結果進行比較。然而,這些結果無法與表1中報告的結果直接對比,因為在 CIFAR10 數據集上使用 SGD 訓練的 ResNets18 在采用 PC 和 iPC 方法時,準確率分別僅為 39.9% 和 43.2%。
為了更好地理解不同超參數對模型最終測試準確率的影響,我們在圖18中展示了它們的重要性圖(importance plots)。這些重要性是通過擬合一個以超參數為數據點、準確率為標簽的隨機森林回歸器,并提取特征重要性計算得到的。
E 將跳躍連接引入 VGG19(Skip Connections into VGG19)
跳躍連接(Skip connections):我們研究了將跳躍連接引入 VGG19 架構的效果,以提升其在 CIFAR10 圖像分類任務中的表現,結果顯示測試準確率從25.32% 顯著提升至 73.95%。消失梯度問題是在深度預測編碼(PC)模型中一個顯著的挑戰,隨著網絡深度的增加,這一問題變得更加突出,阻礙了誤差向更早層的傳遞,影響了學習效果。為了解決這個問題,我們引入了跳躍連接,使得梯度可以跳過多層直接傳播,從而增強了梯度流動和整體的學習性能。
結果(Results)我們改進后的 VGG19 模型在特征提取階段從較早的一層引入了一個跳躍連接,輸出經過展平和線性層調整后,在分類階段重新整合進網絡。該模型在 CIFAR10 數據集上進行了嚴格的訓練與評估,采用了標準的預處理技術,如歸一化、數據增強(水平翻轉和旋轉)。通過詳細的超參數調優,我們為有和沒有跳躍連接的模型都找到了最佳配置,探索了不同的優化器、學習率、動量值和權重衰減設置。結果顯示,引入跳躍連接顯著提升了模型性能,總結見表11。
圖19 展示了在 CIFAR10 數據集上,使用三種不同隨機種子、相同超參數的情況下,帶跳躍連接與不帶跳躍連接的 VGG19 模型在30個訓練輪次中的測試準確率變化趨勢。
F 預測編碼網絡的性質(Properties of Predictive Coding Networks)
本節描述了第 F.1 節的實驗設置,并展示了利用 PCN 分類器的自由能來區分分布內(In-Distribution, ID)和分布外(Out-of-Distribution, OOD)數據的有效性(Liu 等人,2020)。我們展示了如何在 PCN 下計算多個數據集的負對數似然(Grathwohl 等人,2020)。此外,我們還分析了在狀態最優前后,最大 softmax 值與能量值之間的關系。我們在多個數據集上進行比較,以驗證我們的結果,并展示僅通過一個訓練好的 PCN 分類器,就能即插即用地用于 OOD 檢測,并基于 softmax 和能量得分的不同百分位數,研究其接收者操作特征曲線(ROC 曲線)。
F.1 自由能與分布外數據(Free Energy and Out-of-Distribution Data)
借助 PCX,我們可以方便地檢查和分析 PCN 的多種性質。在此,我們使用自由能F來區分由于語義分布偏移(semantic distribution shift)導致的分布內(ID)和分布外(OOD)數據(Liu 等人,2020),并用于計算某個數據集的似然值(Grathwohl 等人,2020)。
這種情況可能發生在樣本來自不同、未見過的類別時,例如在 MNIST 設置下出現 FashionMNIST 樣本(Hendrycks & Gimpel, 2017)。
實驗設置(Experimental Setup)我們在 MNIST 數據集上訓練一個 PCN 分類器,使用具有 3 個隱藏層的前饋 PCN,每層大小為 H = 512,激活函數為 “GELU”,輸出層采用交叉熵損失函數。我們使用早停機制(early stopping)在第 75 輪停止訓練,直到測試誤差收斂。
在訓練過程中,狀態變量通過 SGD 進行優化,推理步數 T = 10,狀態學習率為 γ = 0.01,無動量。權重則使用帶動量 mθ = 0.9 的 SGD 優化器更新,權重學習率設為 lrθ = 0.01。
在測試階段推理時,我們對狀態變量持續優化至收斂,推理步數設為 T = 100。
為了理解 PCN 預測的置信度,我們將 ID 和 OOD 樣本的能量分布與分類器生成的 softmax 分數分布進行對比。我們通過以下方式計算 ID 和 OOD 樣本在 PCN 分類器下的負對數似然:
我們在MNIST數據集上進行實驗作為分布內(ID)數據,并將其與多個分布外(OOD)數據集進行對比,包括notMNIST、KMNIST、EMNIST(字母類)以及FashionMNIST。
簡要來說,圖20a中的結果表明,一個訓練好的 PCN 分類器可以有效地:
- 即插即用地
(out-of-the-box)識別分布外樣本,而無需專門為此目的進行訓練(Yang 等人,2021);
在狀態變量h優化之前,生成與 softmax 值初始相關的能量得分(energy scores),用于區分 ID 和 OOD 樣本。
然而,在對狀態變量進行了 T 次推理步數的優化之后,ID 和 OOD 樣本的得分變得不再相關,尤其是對于那些 softmax 值較低的樣本,如圖20b所示。
為了驗證這一觀察結果,我們還展示了最具挑戰性的樣本(即得分最低的前25%)的 ROC 曲線。如圖20c所示,基于概率(即基于能量)的得分能夠更可靠地判斷樣本是否為 OOD。
其他數據集上的實驗細節和結果詳見附錄 F。此外,我們還在下方提供了EMNIST(字母類)和KMNIST數據集上的更多、更詳細的結果。
結果。下面我們根據各種圖表支持的實驗,簡要解釋這些額外的結果。從圖21中我們可以看到,在測試時狀態優化之前和之后能量是如何分布的。可以看出,所有OOD數據集的初始能量和最終能量都明顯大于ID數據集(MNIST)。
在圖22中,我們進一步通過疊加狀態優化前后能量的直方圖,展示了每個OOD數據集的能量分布與同分布(ID)數據集能量的對比情況。我們可以看到,通過繪制直方圖,出現了一種模式,即大多數OOD數據樣本與ID數據樣本之間沒有重疊,這支持了能量值可以用于OOD檢測的觀點。
接下來在圖23中,我們展示了當比較ID與OOD數據集的softmax分數時,這種模式可能呈現的樣子。可以看出,softmax分數在判斷樣本是否為OOD時提供的信息較少,從ID和OOD樣本在softmax值范圍上有更大的重疊就可以看出這一點。
在圖24中,我們進一步研究了狀態收斂前后softmax分數與能量值之間的關系。該圖顯示,在推理之前能量與softmax分數之間存在高度相關性,但在收斂之后表現出一種非線性關系,尤其是對于較小的值,模型更加不確定的情況下更為明顯。這表明,softmax分數和能量值在判斷哪些樣本應降低置信度方面并不完全一致。
在圖25中,我們展示了所有數據集在推理前后能量分布的情況。每個箱線圖代表不同的場景和不同的數據集。此外,我們還計算了每個數據集的負對數似然(NLL),并將其作為箱線圖標簽的一部分進行展示。我們觀察到,在所有OOD數據集中,初始和最終的能量值都明顯高于MNIST(ID)數據集。此外,可以看出,同分布數據的能量得分方差更小,這一點可以從MNIST數據在箱線圖須范圍之外幾乎沒有離群樣本來體現。最后,每種場景下的NLL值也驗證了這一觀察結果,MNIST數據的似然顯著高于OOD數據分布。
最后,在圖26中,我們展示了PCN如何用于將樣本分類為屬于ID數據還是某種OOD數據。我們使用PCN分類器的能量來進行OOD檢測,并展示了基于能量的檢測所得到的ROC曲線優于通過softmax分數生成的ROC曲線。當通過選取分數和能量的25%百分位來觀察最具挑戰性的樣本時——即那些能量或softmax值較小、反映PCN模型最不確定的樣本——這一觀察結果變得更加明顯。
G 計算資源
圖8是通過使用一個由兩層組成、每層64個神經元的小型前饋PCN模型,在批次大小為32的隨機噪聲數據上進行訓練(使用隨機噪聲是為了避免因將訓練數據加載到GPU而產生的額外開銷),并設置T = 8步所得。然后,對每個參數獨立進行縮放,以測量其對總訓練時間的影響。通過這種方式獲得的每個模型都訓練了5個epoch,并報告了平均時間。在我們所有的時間測量中,跳過了第一個epoch,以避免包含JIT編譯時間。實驗結果是在一塊GTX TITAN X顯卡上獲得的,結果顯示即使在消費級GPU上也具有實現并行化的潛力。
原文鏈接: https://arxiv.org/pdf/2407.01163?
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.