Only Strict Saddles in the Energy Landscape of Predictive Coding Networks?
預測編碼網絡的能量景觀中只有嚴格鞍點嗎?
https://arxiv.org/pdf/2408.11979
摘要
預測編碼(PC)是一種基于能量的學習算法,在更新權重之前對網絡活動進行迭代推理。最近的研究表明,由于其推理過程,PC 可能在比反向傳播更少的學習步驟中實現收斂。然而,這些優勢并不總是能夠觀察到,并且 PC 推理對學習的影響在理論上尚未被充分理解。在這里,我們研究了在網絡活動的推理平衡狀態下 PC 能量景觀的幾何結構。對于深度線性網絡,我們首先證明平衡狀態下的能量實際上是一個重新縮放的均方誤差損失(MSE),其中包含一個與權重相關的重新縮放因子。然后我們證明,該損失中許多高度退化(非嚴格)的鞍點(包括原點)在平衡能量中變得更加容易逃離(嚴格)。我們的理論通過在線性和非線性網絡上的實驗得到了驗證。基于這些和其他結果,我們推測平衡能量的所有鞍點都是嚴格的。總體而言,這項工作表明,PC 推理使損失景觀更加良性,并對梯度消失具有更強的魯棒性,同時也突出了將 PC 擴展到更深模型的基本挑戰。
1 引言
作為一種源自大腦功能一般原則的學習機制,預測編碼(PC)近年來發展成為一種局部學習算法,可能為反向傳播(BP)提供生物學上合理的替代方案 [32, 31, 43]。使用 PC 訓練的深度神經網絡(DNNs)在小型到中型機器學習任務上表現出與 BP 相當的性能,包括分類、生成和記憶關聯 [31, 43, 41]。PC 網絡(PCNs)也非常靈活,支持任意計算圖 [45, 10]、混合和因果推理 [44, 59] 以及時間預測 [35]。
與 BP 不同,類似于其他基于能量的算法 [例如 49, 38],PC 在權重更新之前會對網絡活動進行迭代(近似貝葉斯)推理。這最近被稱為“前瞻性配置”的學習原理,被認為是大腦學習信用分配的一種根本不同的方式 [54],其中權重跟隨活動(而不是相反)。雖然 PC 的推理過程帶來了額外的計算成本,但它被認為可以帶來許多學習上的好處,包括更快的收斂速度 [54, 3, 18]。然而,這些加速效果并未在所有數據集、模型和優化器中一致地觀察到 [3],并且 PC 推理對學習的總體影響在理論上尚未被充分理解(見附錄 A.2.1)。
為了填補這一空白,我們研究了 PC 學習所依賴的有效景觀的幾何結構:即在網絡活動達到推理平衡時的權重景觀(定義見 §2.2)。我們的理論基于深度線性網絡(DLNs),這是研究損失景觀的標準模型(見附錄 A.2)。盡管 DLNs 只能學習線性表示,但它們的損失景觀是非凸的,具有復雜的非線性學習動力學,已被證明是理解非線性網絡的有用模型 [例如 48]。與之前的 PC 理論 [3, 2, 18] 不同,我們沒有做出任何額外的假設或近似(見附錄 A.2),并通過實驗證明我們的線性理論也適用于非線性網絡。
對于 DLNs,我們首先證明在推理平衡狀態下,PC 能量只是一個重新縮放的均方誤差(MSE)損失,其中包含一個非平凡的、依賴于權重的重新縮放因子(定理 1)。然后我們將損失函數中的鞍點(最近已有相關研究 [23, 1])與平衡能量中的鞍點進行比較。這類鞍點在神經網絡的損失景觀中普遍存在 [11, 1],主要分為兩種類型:“嚴格”鞍點,其中 Hessian 矩陣是不定的(定義 1);和“非嚴格”鞍點,其中逃逸方向存在于高階導數中 [15, 23, 1]。對于像(隨機)梯度下降(SGD)這樣的 一階方法 來說,非嚴格鞍點尤其成問題,因為它們至少是二階臨界點。雖然 SGD 在嚴格鞍點附近可能會指數級變慢 [12],但它在非嚴格鞍點處可能會有效地卡住 [47, 7](見附錄 A.2 的綜述)。從損失景觀的角度來看,這就是所謂的梯度消失現象 [39, 6]。
相比之下,我們在此證明了許多 MSE 損失中的非嚴格鞍點,特別是零秩鞍點,在任何深度線性網絡的平衡能量中都變成了嚴格的鞍點(定理 2 和 3)。這些鞍點包括原點,其在損失函數中的退化程度(即平坦度)隨著隱藏層數量的增加而增長。我們的理論結果通過在線性和非線性網絡上的實驗得到了強有力的驗證,進一步的實驗還表明,損失函數中的其他(更高秩)非嚴格鞍點在平衡能量中也是嚴格的。基于這些結果,我們推測平衡能量的所有鞍點都是嚴格的。總體而言,這項工作表明,PC 推理使損失景觀更加良性,并對梯度消失具有更強的魯棒性,同時也突出了加快 PC 推理在更深網絡中的基本挑戰。
其余部分的結構如下:在介紹設置之后(§2),我們展示了針對深度線性網絡的理論結果(§3),包括一些示例和對每個結果的詳細實證驗證。然后我們報告了支持我們理論和更廣泛猜想的非線性網絡實驗(§4)。最后,我們討論了本工作的意義和局限性,以及潛在的未來方向(§5)。附錄 A 包括相關工作的回顧、推導、實驗細節和補充結果。所有實驗的代碼可在 https://github.com/francesco-innocenti/pc-saddles 獲取。
1.1 貢獻總結
我們推導出了深度線性網絡(DLNs)在推理平衡狀態下的 PC 能量的精確解 (定理 1),結果表明它實際上是均方誤差損失(MSE)的一個重新縮放版本,其中縮放因子依賴于網絡權重。這一發現糾正了之前文獻中的一個錯誤觀點:即 MSE 損失等于輸出能量 [34](這僅在前饋階段成立),并為進一步研究 PC 能量景觀提供了基礎。我們的理論與實驗結果高度吻合(見圖 1)。
基于上述結果,我們證明了與 MSE 損失不同的是,DLNs 在平衡能量下的原點是一個嚴格鞍點(strict saddle),且該性質與網絡深度無關 。我們對平衡能量在原點處的 Hessian 矩陣進行了明確刻畫(定理 2),并通過在線性網絡上的實驗完美驗證了這一點(見圖 3、圖 4 和圖 8)。
我們進一步證明了,除了原點以外,MSE 損失中的一些非嚴格鞍點(特別是零秩鞍點)在 DLNs 的平衡能量中也變成了嚴格鞍點 (定理 3)。我們通過實驗對其中一個鞍點進行了實證驗證作為示例(見圖 9 和圖 10)。
我們通過實驗表明,我們的線性理論同樣適用于非線性網絡 ,包括在標準圖像分類任務上訓練的卷積網絡。特別地,當初始化接近 MSE 損失中由定理 3 所涵蓋的非嚴格鞍點時,我們發現基于平衡能量的 SGD 比基于原始損失的 SGD 更快逃離這些鞍點(在相同學習率下)(見圖 5 和圖 12)。與 BP 相比,PC 不會出現梯度消失現象(見圖 11)。
我們還進行了額外的實驗(仍包括線性和非線性網絡),結果顯示,PC 也能快速逃離我們理論上未涉及的其他(更高秩)非嚴格鞍點 (見圖 6),從而進一步支持我們的猜想:平衡能量的所有鞍點都是嚴格鞍點
2 前備知識(預備內容)
2.2 預測編碼(PC)
使用預測編碼(PC)訓練的深度神經網絡(DNNs)通常假設一個具有單位協方差的層次化高斯模型,因此我們對線性全連接層采用這種設定:其中每一層的平均活動 ze是前一層的線性函數。
在對生成模型做出一些常見的附加假設后,我們可以推導出一個能量函數,通常被稱為變分自由能(variational free energy),該函數是各層預測誤差平方和的形式 [9]。
3 理論結果 3.1 平衡能量作為重新縮放的均方誤差(MSE)
正如 §2.2 中所解釋的,預測編碼網絡(PCN)的權重通常在活動達到平衡后進行更新。我們稱之為平衡能量,并將其簡寫為 。因此,平衡能量是 PC 導航的有效學習景觀,也是我們感興趣的研究對象。事實證明,我們可以推導出深度線性網絡(DLNs)平衡能量的封閉形式解,這將成為我們后續結果的基礎。
該證明依賴于展開 PC 所假設的層次化高斯模型,以推導出輸出的完整隱式生成模型,而其中的縮放因子 S 來自于 PC 在每一層所建模的方差(詳見 §A.3.2)。圖 1 展示了該理論的極佳實證驗證結果。
直觀上,PC 的推理過程(公式 3)可以被理解為對(MSE)損失景觀進行重塑,使其能夠考慮到各層、依賴權重的方差。這立即引出了一個問題:平衡能量景觀與損失景觀有何不同?
這種重新縮放——也就是 PC 所建模的各層方差——對學習是否有幫助?
在下文中,我們通過比較這兩個目標函數的鞍點幾何結構,對該問題給出了一個部分肯定的回答。
3.2 對原點鞍點(θ = 0)的分析
在這里我們證明,與 MSE 損失不同,對于任意深度的深度線性網絡(DLNs)來說,平衡能量的原點(公式 5,即所有權重都為零的情況,θ = 0)是一個嚴格鞍點 (定義 1)。為此,我們推導出了平衡能量在原點處 Hessian 矩陣的一個顯式表達式。
為了進行直觀比較,我們首先簡要回顧已知的結果:在原點處,對于單隱藏層網絡,損失函數的 Hessian 是不定的 (indefinite);而對于任何更深的網絡,其 Hessian 為零(詳見 §A.3.1 的推導)。
更具體地說,損失函數的原點鞍點是 H1 階 的,隨著網絡深度的增加,它變得越來越退化(平坦),并且更難逃離,尤其是對于像 SGD 這樣的一階方法(見圖 2 中間和右側面板)。
相比之下,我們現在證明:對于任意隱藏層數量的深度線性網絡(DLNs),平衡能量的原點鞍點是嚴格鞍點 。
圖 2 展示了一些簡單的示例來說明這一結果。簡而言之,我們觀察到,當初始化接近原點鞍點時,隨著深度增加,SGD 從損失函數中逃離所需的時間越來越長,而從能量函數中逃離則相對更快。
下面我們更正式地陳述這一結果(在相同學習率的情況下)。
我們發現,平衡能量在原點處的 Hessian 矩陣 實際上是(見 §A.3.3 的推導):
其中 是經驗輸出協方差矩陣。
我們看到,與損失的 Hessian 矩陣(公式 6)不同的是,對于任意數量的隱藏層 H,能量的 Hessian 矩陣的最后一個對角塊是非零的。接下來可以很容易地證明,能量的 Hessian 矩陣總是具有負特征值,因為輸出協方差矩陣是正定的。
圖 3 和圖 4 顯示了在平衡能量原點處,理論 Hessian(公式 8)與數值計算的 Hessian 之間 完全吻合 。我們針對一系列深度線性網絡,在隨機生成的玩具數據集以及更具現實意義的數據集上進行了該計算。
定理 2 證明了對于任意深度的深度線性網絡(DLNs),原點是平衡能量的一個嚴格鞍點。這與均方誤差(MSE)損失形成了鮮明對比,因為在 MSE 損失中,這一性質僅對單隱藏層網絡 H =1(公式 7)成立。該結果預測,在原點附近,給定相同的學習率,(S)GD 應該在平衡能量上比在損失函數上更快逃離鞍點,并且隨著深度的增加,這種差異會更加顯著。圖 2 確認了這一預測在一些簡單的線性網絡上的有效性,而 §4 中的圖 5 和圖 6 清楚地表明,這一結論同樣適用于非線性網絡。
3.3 對其他鞍點的分析
原點是否只是這樣一個特例:在平衡能量中比在損失函數中更容易逃離鞍點?
還是說,這一結果揭示了某種更普遍的現象?
在這里,我們考慮一類特定的、損失函數中的非嚴格鞍點(其中原點是其中之一),并證明它們在平衡能量中也確實變成了嚴格鞍點。
我們在第 4 節中通過實驗探討了其他類型的鞍點,并將它們的理論分析留作未來研究。
具體來說,我們考慮零秩鞍點 (rank zero saddle)。
請注意,定理 2 現在可以看作是定理 3 的一個推論,盡管對于原點我們推導出了完整的 Hessian 矩陣。這一結果也與(MSE)損失形成了鮮明對比:在 MSE 損失中,許多被考慮的臨界點(特別是當有三個或更多權重矩陣為零時)是非嚴格鞍點,這一點已被 [1] 證明。
我們的預測再次是:在這些鞍點附近,給定相同的學習率,PC 應該比使用(S)GD 的 BP 更快逃離鞍點。由于篇幅限制,后續實驗僅以原點作為一個由定理 3(以及定理 2)涵蓋的鞍點示例,但 §A.5 中包含了對另一個(零秩)平衡能量嚴格鞍點的實證驗證(見圖 9、圖 10 和圖 12)。我們的代碼也使得測試其他鞍點相對容易。
4 實驗
在這里,我們報告了線性網絡和非線性網絡的實驗,這些實驗支持了我們的理論結果,同時也驗證了更一般的猜想:平衡能量的所有鞍點都是嚴格鞍點。在所有實驗中,我們使用 BP 和 PC 以相同的 learning rate 訓練網絡,因為我們的目標是驗證平衡能量景觀鞍點幾何的理論。
所有結果的復現代碼可在 https://github.com/francesco-innocenti/pc-saddles 獲取。
首先,我們比較了線性網絡和非線性網絡(包括卷積架構)在標準圖像分類任務上的損失訓練動態,這些網絡使用 SGD 初始化接近原點(詳見 §A.4)。出于計算原因,我們沒有讓 BP 訓練的網絡收斂,這突顯了損失函數的原點鞍點高度退化,特別是對于像 SGD 這樣的一階方法來說,很難逃離。在所有情況下,我們觀察到 PC 逃離原點鞍點的速度明顯快于 BP(見圖 5),并且圖 11 表明 PC 沒有出現梯度消失的現象。當我們初始化接近定理 3 所涵蓋的另一個非嚴格鞍點時,我們得到了幾乎相同的結果(見圖 12)。這些發現支持了我們的理論結果,不僅限于線性情況。
從圖 5 中,我們還觀察到 PCNs 損失動態中的第二個平臺期,這表明存在一個更高秩的鞍點(可能是秩為 1 的鞍點)。這一現象與 [19] 對深度線性網絡(DLNs)描述的“鞍點到鞍點”動力學一致,在該動力學中,對于小初始化,GD 通過一系列鞍點進行過渡,每個鞍點代表一個遞增秩的解。
為了明確測試我們未在理論上研究的更高秩、非嚴格鞍點,我們復制了 [19] 在矩陣補全任務上的一項實驗(參見其圖 1)。具體而言,網絡被訓練以擬合一個秩為 3 的矩陣,這意味著從接近原點開始,GD 訪問了 3 個鞍點(依次為秩 0、秩 1 和秩 2),最終收斂到一個秩為 3 的解,如圖 6 所示。
我們發現,當初始化接近 BP 訪問的任何鞍點時,PC 能夠快速逃離,并且沒有出現梯度消失的現象(見圖 6),這支持了平衡能量的所有鞍點都是嚴格鞍點的猜想。
從圖 5 中,我們還觀察到 PCNs 損失動態中的第二個平臺期,這表明存在一個更高秩的鞍點(可能是秩為 1 的鞍點)。這一現象與 [19] 對深度線性網絡(DLNs)描述的“鞍點到鞍點”動力學一致,在該動力學中,對于小初始化,GD 通過一系列鞍點進行過渡,每個鞍點代表一個遞增秩的解。
為了明確測試我們未在理論上研究的更高秩、非嚴格鞍點,我們復制了 [19] 在矩陣補全任務上的一項實驗(參見其圖 1)。具體而言,網絡被訓練以擬合一個秩為 3 的矩陣,這意味著從接近原點開始,GD 訪問了 3 個鞍點(依次為秩 0、秩 1 和秩 2),最終收斂到一個秩為 3 的解,如圖 6 所示。
我們發現,當初始化接近 BP 訪問的任何鞍點時,PC 能夠快速逃離,并且沒有出現梯度消失的現象(見圖 6),這支持了平衡能量的所有鞍點都是嚴格鞍點的猜想。
5 討論
總結來說,我們邁出了一步,開始刻畫 PC 學習的有效景觀——即推理平衡狀態下的能量景觀。
對于深度線性網絡(DLNs),我們首先證明了平衡能量是一個權重依賴的重新縮放后的均方誤差(MSE)損失(定理 1)。這一結果糾正了文獻中先前的一個錯誤觀點:即 MSE 損失等于輸出能量 [34],并且總能量(公式 2)因此可以分解為損失和其余(隱藏層)能量(這種關系僅在前饋活動值時成立)。如我們在下文所述,公式 5 還為進一步研究 PC 的學習景觀提供了基礎。
接著,我們證明了 MSE 損失中的許多非嚴格鞍點,特別是零秩鞍點,在任何 DLN 的平衡能量中都變成了嚴格鞍點(定理 2 和定理 3)。這些鞍點包括原點,使得 PC 對梯度消失現象更加魯棒(見圖 6 和圖 11)。我們通過在線性和非線性架構上的實驗徹底驗證了我們的理論,并為平衡能量中更高秩鞍點的嚴格性提供了實證支持。基于這些結果,我們猜想平衡能量的所有鞍點都是嚴格的。總體而言,PC 的推理過程可以被解釋為使損失景觀變得更加良性。
5.1 含義(意義與影響)
我們的工作在解釋力和預測能力方面,顯著超越了現有的關于預測編碼(PC)的理論。大多數先前的研究都基于非標準假設或粗略近似,導致實驗預測不夠具體。例如,文獻 [3] 將 PC 解釋為一種隱式的梯度下降(implicit GD),但該結論僅適用于小批量數據,并且依賴于對活動值和參數學習率的逐層重新縮放。([2] 擴展了這一結果,去除了活動值的重新縮放,但并未去除學習率的縮放。)
相比之下,我們的理論唯一的主要假設是線性,而且我們通過實驗證明了所有結果在非線性網絡中也成立。同樣地,[2] 和 [18] 都通過對能量函數進行二階近似來論證 PC 利用了 Hessian 信息。然而,我們的研究清楚地表明,PC 可以利用更高階的信息,將高度退化、H 階的鞍點轉化為嚴格鞍點。
先前的理論也難以解釋為何在不同的任務、模型和優化器下,PC 的學習加速并不總是能被觀察到 [3, 54]。我們的景觀分析雖然還不完整(詳見下文),但承認了這些因素及其相互作用,有助于解釋不一致的發現,并預測何時可以期待、何時無法期待加速效果。在其他條件相同的情況下,PC 在深層且窄 的網絡上應該收斂得更快(盡管如我們在下文中討論的那樣,不能太深),因為原點鞍點與標準初始化之間的距離隨著網絡寬度增加而增大 [39]。這很可能解釋了 [54] 在一個較窄(每層 n? = 64)的 15 層全連接網絡中所報告的學習加速現象。然而,在實踐中,其他條件往往并不相等,從未能達到推理平衡,到不同數據集、架構和優化器之間的相互作用都會影響最終的收斂表現。這就引出了一個問題:最小化平衡能量是否比最小化損失函數更快或性能更好? 我們將在后文再次回到這個問題。
更廣泛地說,我們的景觀理論與 [56] 的工作密切相關,他們展示了在線性物理系統中使用平衡傳播(equilibrium propagation)[49, 50] 進行學習會對活動值(而非權重)的 Hessian 矩陣產生有益影響。探索這些聯系——以及更一般意義上推理過程對于能量型系統學習的好處——可能是一個有趣的研究方向。
我們的研究還對大腦中信用分配(credit assignment)的理論具有啟示意義。特別是,我們的結果為最近提出的“前瞻性配置”(prospective configuration)原理 [54] 提供了更為堅實的理論基礎,表明 PC 的推理確實可以通過利用高階信息來促進學習。與此同時,我們的研究也表明,有關 PC 能普遍加快學習速度的說法可能被夸大了 [54]。
5.2 局限性
最后,我們討論本工作的主要局限性。
首先,根據我們的推導,所研究的能量鞍點的“嚴格性”成立的前提是處于精確的推理平衡狀態 。我們注意到,即使沒有達到完全的平衡狀態,PC 也有可能改善損失鞍點的退化問題,在這個意義上,PC 可以被視為一種資源。然而在實踐中,PC 推理需要越來越多的迭代次數才能在更深的網絡上收斂,這與我們的景觀理論是一致的:因為隨著深度增加,損失鞍點變得越來越退化。因此,我們的結果突顯了一個根本性的挑戰:如果要在大規模任務中實現 PC 的學習優勢,就必須加快其在深層模型上的推理速度[40]。
即使這一挑戰被克服了,對于深度網絡的實際訓練來說,似乎還存在兩個相互關聯的關鍵問題。
第一個問題是:是否存在某些條件下,可以以更少的計算或內存開銷、更快地最小化平衡能量,并且至少獲得與傳統方法相當的性能 ?例如,像 Adam [24] 這樣的優化工具和殘差連接(skip connections)[17] 能夠幫助逃離原點鞍點,但代價是更高的內存消耗。這種代價是否可以與 PC 推理帶來的計算成本進行權衡?對 PC 推理成本進行更正式的刻畫將是朝這個方向邁出的有用一步。
第二個問題是:是否存在一些場景,盡管 PC 更慢或效率更低,但卻能顯著提升模型性能 ?這是一個難以回答的問題,因為我們距離建立一個完整的深度學習泛化理論還有很大差距 [63, 20]。不過,根據我們關于原點鞍點的結果(定理 2),值得注意的是,在那些低秩先驗有益的問題上(例如矩陣補全問題,見圖 6),使用小初始化的 GD 可以比標準初始化收斂到泛化能力更強的解 [19]。
最后,要全面理解 PC 的整體收斂行為,還需要進一步刻畫平衡能量的其他臨界點,尤其是它的極小值 [14]。我們的工作——特別是公式 5——為這一研究提供了可能。在附錄 §A.3.7 中,我們進行了初步探索,結果顯示對于線性鏈狀結構,平衡能量的全局極小值比 MSE 損失的極小值更平坦。這一結果或許可以解釋一個常見的觀察現象:PC 在訓練后期往往會逐漸減緩收斂速度,但我們將其全部意義留待未來研究。
A.1 通用符號與定義
A.2 相關工作 A.2.1 預測編碼的理論
預測編碼(PC)與反向傳播(BP)
[60] 是最早證明在多層感知機中,當輸入的影響相對于輸出被加強時,PC 可以近似 BP 的研究。[36] 將這一結果推廣到了任意計算圖,包括卷積神經網絡和循環神經網絡,前提是滿足所謂的“固定預測假設”(fixed prediction assumption)。
后來有研究提出了一種變體形式的 PC:如果權重在精確的時間點進行更新,則它可以在多層感知機上完全等價地計算出與 BP 相同的梯度 [53],這一結果隨后被 [46] 和 [42] 進一步推廣。[33] 從能量建模的角度統一了這些以及其他關于 PC 近似 BP 的結果。[62] 證明了所有這些 PC 變體的時間復雜度都不低于 BP。
預測編碼與其他算法[13] 對近似 BP 的各種 PC 變體在推理階段的收斂性進行了深入的動力系統分析。[34] 表明,在線性網絡中,PC 推理平衡狀態可以被解釋為 BP 前饋傳遞值與通過目標傳播(target propagation)計算出的局部目標之間的平均值。[54] 提出,PC 和其他基于能量的算法實現了一種根本不同的信用分配機制,稱為“前瞻性配置”(prospective configuration),即神經元首先改變其活動以與目標對齊,然后更新權重以鞏固這種活動模式。對于批量大小為 1 的情況,[3] 證明了在特定的逐層重新縮放條件下(包括對活動值和參數學習率的縮放),PC 可以近似隱式梯度下降(implicit gradient descent)。[2] 進一步表明,當該近似成立時,PC 可能對 Hessian 矩陣的信息敏感。類似地,最近的研究將 PC 描述為一種二階信任域方法(second-order trust-region method)[18]。
A.2.2 鞍點與神經網絡
在這里,我們回顧一些關于以下兩個方面的相關理論和實證研究:(i) 神經網絡損失景觀中的鞍點;(ii) 不同學習算法(尤其是 (S)GD)在鞍點附近的性能表現。關于神經網絡損失景觀與優化問題的更一般性綜述,可參見 [57] 和 [58]。
神經網絡損失函數中的鞍點 這項研究始于 [5] 的工作,他們表明對于具有一個隱藏層的線性網絡,MSE 損失的所有臨界點要么是全局最小值,要么是嚴格鞍點(定義 1)。同樣針對該模型,[48] 后續展示了在小初始化條件下存在從一個鞍點到另一個鞍點的學習過渡,并在特定數據假設下刻畫了梯度下降(GD)的動力學行為。[11] 強調了在高維非凸的神經網絡損失中,鞍點相對于局部極小值更為普遍。特別是,他們通過實驗演示了網絡損失景觀與隨機高斯誤差函數之間在定性上的相似性:一個臨界點所關聯的誤差越高,它就越可能是鞍點,且這種可能性呈指數增長趨勢 [8]。
[23] 著名地將 [5] 的結果推廣到了任意深度的線性網絡(DLNs),證明在對數據做出一些弱假設的前提下,所有的局部極小值都是全局極小值。這一結果隨后被 [29] 在更寬松的假設下進行了簡化與擴展。重要的是,[23] 第一次表明:對于只有一個隱藏層(H = 1)的神經網絡,所有鞍點都是嚴格鞍點(或稱為一階鞍點),而更深的網絡則包含非嚴格鞍點(H 階鞍點),例如權重全部為零的原點。此后,一系列變體和擴展的研究成果相繼出現 [61, 64, 25, 65, 37, 66]。對于我們當前的目的,一個重要擴展是由 [1] 提出的,他們對 DLN 的 MSE 損失的所有臨界點進行了二階刻畫,包括嚴格鞍點和非嚴格鞍點。
靠近鞍點時的學習行為 這項研究可以追溯到 [15],他們表明帶有噪聲的 SGD 可以在多項式時間內收斂于嚴格鞍點函數。[27] 證明了一個類似的結果:在幾乎所有的初始化條件下,不帶任何噪聲的 GD 最終會逃離嚴格鞍點。這一結論后來被推廣到其他一階方法 [26]。[21] 證明了另一種帶有噪聲的 GD 版本可以在與維度相關的對數多項式時間內以高概率收斂到二階臨界點。關于 GD 及其變體的這些以及其他收斂性結果的綜述,可參見 [22]。[4] 表明:(i) 存在一種 GD 的進一步變體可以被證明收斂到三階臨界點并逃離二階鞍點,但計算成本很高;(ii) 尋找更高階的臨界點是 NP-難問題。 [12] 證明了一個重要的結果:雖然標準 GD 在常見初始化條件下最終會逃離嚴格鞍點,但它可能需要 指數級的時間 才能做到這一點。這與前面提到的帶擾動的 GD 版本形成對比,后者可在多項式時間內收斂。同樣地,[51] 證明了在線性鏈狀結構或寬度為 1 的一維網絡中,GD 的收斂時間隨深度呈指數增長。[39] 分析了類似的模型,并表明除非適當調整寬度,否則梯度和曲率都會隨著網絡深度增加而消失。[39] 認為這在一定程度上解釋了自適應梯度優化器(如 Adam [24])的成功之處,因為它們能夠適應平坦的曲率區域。類似地,[55] 表明自適應方法可以通過在臨界點附近重新縮放梯度噪聲使其各向同性,從而更快地逃離鞍點。
[19] 提出了一個鞍點到鞍點的動力學猜想,其中梯度下降(GD)會依次經過一系列 秩逐漸增加 的鞍點,然后最終收斂到一個 稀疏的全局最小值 。一些研究也表明,在實際中,隨機梯度下降(SGD)可能會收斂到 二階臨界點 ,這些臨界點是 非嚴格鞍點 而非最小值 [47, 7]。
A.3 證明與推導
A.3.1 深度線性網絡(DLNs)的損失 Hessian 矩陣
這里我們推導了均方誤差(MSE)損失(公式 1)關于任意深度線性網絡(DLNs)權重的 Hessian 矩陣(見 §2.1);這基本上是對 [52] 中結果的重新推導,只是采用了略有不同的記號。2 接著,我們展示了在原點處(θ = 0),Hessian 矩陣及其特征譜如何隨著隱藏層個數 H 的變化而變化。我們從給定權重矩陣的損失梯度開始推導。
對于單隱藏層網絡,Hessian 矩陣是 不定的 (indefinite),其正負特征值由經驗輸入-輸出協方差給出,如 [48] 所述。對于任何具有多個隱藏層的深度線性網絡(DLN),Hessian 矩陣為零,因此原點是一個 二階臨界點 。在一般情況下,該點是一個 非嚴格鞍點 ,因為損失函數的某些依賴于網絡深度的高階導數中將至少存在一個負的方向(逃逸方向)。更具體地說,對于一個具有 L 層的網絡,所有階數小于 L 的導數都為零,而負的方向將出現在階數大于等于 L 的導數中。
A.3.2 深度線性網絡(DLNs)的平衡能量
其中目標被建模為一個高斯分布,其均值由網絡函數給出,協方差為 Σ 。相比之下,在預測編碼(PC)網絡中,不僅僅是輸出層, 每一個隱藏層的激活值 都被建模為高斯分布(見 §2.2)。
我們現在可以通過對公式 27 關于 y 取期望和方差,來推導出目標的隱式生成模型。
A.3.3 深度線性網絡(DLNs)平衡能量的 Hessian 矩陣
在此,我們推導深度線性網絡(DLNs)平衡能量在原點處的 Hessian 矩陣,其推導過程與損失函數 Hessian 的計算類似(見 §A.3.1)。附錄 A.3.5 展示了針對一維線性網絡的等效推導,該推導保留了所有關鍵直覺且更易于理解。我們從先前為 DLNs 推導出的平衡能量開始(見 §A.3.2,公式 29),結果表明該能量實際上是以下經過重新縮放的 MSE 損失:
計算 Hessian 矩陣需要多次應用乘積法則,因此為了簡化分析,我們分別考察每個項(公式 31 和 32)在原點處的導數貢獻。由于第一項僅僅是損失函數的一個重新縮放形式,并且根據公式 33,它關于同一權重矩陣在零處的二階導數恒為零。
B 的二階導數需要五次應用乘積法則,涉及殘差的一階導數(及其轉置),以及重新縮放項的一階和二階導數。如上所示(公式 34),重新縮放在原點處的一階導數為零,并且對于任何具有一個或多個隱藏層的網絡,殘差關于任意權重矩陣在零處的導數恒為零,即 。然而,對于最后一個權重矩陣關于其自身的特殊情況,重新縮放的二階導數不為零:
我們看到,與損失函數的 Hessian 矩陣(公式 20)相比,能量函數的 Hessian 矩陣對于任意 H 都具有一個非零的最后一個對角塊。我們注意到——但并未深入探討——它可能與目標傳播(target propagation)[30, 34] 存在聯系。單隱藏層的情形將在下一節中完整推導(§A.3.4)。可以很容易地證明,這些矩陣具有負的特征值:
A.3.4 示例:單隱藏層線性網絡
在此,我們展示一個示例計算,比較具有單個隱藏層(H=1)的深度線性網絡(DLNs)在原點處的損失函數和平衡能量的 Hessian 矩陣。在這種情況下,MSE 損失和平衡能量分別為:
A.3.5 線性鏈結構的平衡能量 Hessian 矩陣
在此,我們推導了線性鏈(即寬度為 1 的網絡)或單位寬度網絡 wL:1x(其中 n0=?=nL=1)的平衡能量的 Hessian 矩陣(以及其在原點處的特征結構)。這一推導與寬網絡情況下的推導(見 §A.3.3)類似,但它展示了所有關鍵見解,且更易于理解。在標量情形下,由預測編碼(PC)定義的目標的隱式生成模型為(見 §A.3.2):
A.3.7 平衡能量的更平坦全局極小值(線性鏈情形)
在此,我們對平衡能量與 MSE 損失的極小值進行初步比較研究。對于線性鏈結構(見 §A.3.5),我們表明,平衡能量的全局極小值比 MSE 損失的極小值更加平坦。更準確地說,平衡能量的全局極小值實際上是損失函數極小值的縮放版本,其縮放因子與平衡能量本身的縮放因子相同(見 §A.3.2)。這一結論推廣了 [18] 中關于單個隱藏單元線性鏈的結果。
A.4 實驗細節
用于復現實驗的代碼可在 https://github.com/francesco-innocenti/pc-saddles 獲取。除非另有說明,所有預測編碼(PC)網絡均使用標準的歐拉積分方法進行推理動力學模擬至達到平衡狀態(見 §2.2,公式 3),積分步長為dt=0.1,迭代次數取決于具體問題。
原文鏈接: https://arxiv.org/pdf/2408.11979
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.