節食式持續學習:在計算受限下從稀疏標注數據流中學習
CONTINUAL LEARNING ON A DIET:LEARNING FROM SPARSELY LABELED STREAMS UN-DER CONSTRAINED COMPUTATION
https://arxiv.org/pdf/2404.12766
摘要
我們提出并研究了一種更貼近現實的持續學習(Continual Learning, CL)設定:在訓練過程中,學習算法在每個時間步被限制使用有限的計算資源。我們將這種設定應用于大規模、標簽稀疏的半監督持續學習場景中。以往表現良好的CL方法在這種具有挑戰性的設定下表現非常糟糕。過擬合于稀疏的有標簽數據以及受限的計算資源是導致性能不佳的兩個主要原因。我們的新設定鼓勵學習方法在訓練過程中有效地、高效地利用無標簽數據。為此,我們提出了一種簡單但高度有效的基線方法——DietCL,該方法聯合利用了無標簽和有標簽數據。DietCL為這兩類數據精心分配計算資源。我們在多個大規模數據集上驗證了該基線方法的有效性,例如CLOC、ImageNet10K和CGLM,在受限預算設置下,DietCL大幅優于所有現有的監督式CL算法以及近期提出的持續半監督學習方法。我們的廣泛分析和消融實驗表明,DietCL在各種標簽稀疏程度、計算預算以及其他多種變量條件下都具有良好的穩定性。我們的代碼已公開:https://github.com/wx-zhang/continual-learning-on-a-diet
1 引言
在這個信息豐富的時代,數據并不是一次性全部呈現出來的,而是從一個非平穩的環境中逐步流式輸入。例如,像YouTube、Snapchat和Facebook這樣的社交媒體平臺每天都會接收到海量的數據。這些數據的內容及其分布很大程度上取決于社會趨勢和平臺上關注的焦點,因此會隨著時間而變化。例如,Snapchat在2017年報告稱,每天從全球用戶那里接收超過35億個短視頻(Snap,2017)。這些視頻必須立即用于各種任務處理,包括圖像評分、推薦系統、仇恨言論檢測和虛假信息識別等。
持續學習旨在應對這類挑戰,其核心目標是設計能夠在接納新數據流的同時保留已有知識的訓練算法。目前已有多種解決方案被提出,包括基于正則化的方法(Kirkpatrick 等,2017)、基于模型結構的方法(Ebrahimi 等,2020),以及基于記憶回放的方法(Chaudhry 等,2019b)。
然而,在大多數實際應用中,數據量極其龐大,并需要實時處理。這種需求對持續學習算法提出了預算限制。為了更好地說明這些限制,我們可以設想這樣一個場景:如果一個學習算法需要10天來學習一個包含35億樣本的數據集(這是Snapchat一天內積累的數據量),那么在訓練期間,新的數據流又會產生350億個新樣本。這將導致模型在部署時已經嚴重滯后于當前的數據分布。這種處理時間上的約束不僅限制了可使用的標簽數量,也對算法的計算時間提出了嚴格要求。
現有文獻已經認識到這一問題,并努力尋找解決方案。盡管一些在線持續學習方法嘗試對在線數據流進行建模(Chaudhry 等,2019a;Cai 等,2021),但它們更注重以批次方式進行訓練和評估,常常忽略了對算法中計算和時間開銷的約束。此外,大多數在線持續學習工作都假設可以獲取完整的標簽集合。
另一方面,離線持續學習領域也取得了一些進展,通過引入無標簽數據來提升性能(Pham 等,2021;Fini 等,2022)。這些方法通常需要對無標簽數據進行完整遍歷,因此忽視了處理大量無標簽數據所帶來的高昂計算需求。直到最近,才有一些研究開始關注受限預算下的持續學習 (例如:Prabhu 等,2023a),旨在限制每個任務的計算資源消耗,從而使得持續學習算法能夠在上述現實場景下得以應用。然而,這些研究仍然集中在完全有標簽的數據流 上。
在本研究中,我們將受限預算下的持續學習擴展為一種半監督的形式 。該場景具有計算資源受限 和標簽稀疏 的特點,因此我們將其稱為“節食型持續學習”(CL on Diet)。這一挑戰的核心如圖1所示:在每一個時間步中都會涌現出大量的數據,但其中只有少量數據附帶有標簽。隨后,學習算法必須在接收數據的時間間隔內,在有限的計算資源約束下對這些數據進行訓練。
在“CL on Diet”設定下,我們首先研究了現有方法應對這種挑戰性設置的能力。確實,我們的研究發現,當前僅依賴有標簽數據的方法容易對有限的標簽產生過擬合。而那些利用無標簽數據的方法則往往需要大量的計算資源來進行多次完整遍歷大規模無標簽數據集,導致在受到計算限制時性能顯著下降。
為了解決這些問題,我們提出了一個高效的基線方法——DietCL。該方法引入了一種計算預算分配機制 ,用于協調當前與過去數據分布的學習。此外,我們還提出了一種統一的訓練策略 ,能夠同時融合來自有標簽和無標簽數據的信息。我們在多個主流的大規模持續學習數據集(包括ImageNet10K、CLOC和CGLM)上驗證了該基線方法的有效性。結果顯示,我們的方法在這一現實場景中達到了最先進的性能。此外,我們還展示了該基線方法在不同數據流長度、計算限制以及標簽率條件下的魯棒性。
我們的貢獻可以總結為以下三個方面:
我們提出了一種具有挑戰性的大規模半監督持續學習設定,命名為“CL on Diet”,其特點是在標簽稀疏且計算資源受限的情況下進行學習。我們探討了現有方法在這一設定下的問題。
我們提出了DietCL,這是一種高效利用計算預算、聯合優化有標簽和無標簽數據的持續學習基線方法。
我們在大規模數據集上進行了廣泛的實驗,包括 ImageNet10K、CLOC 和 CGLM,涵蓋了數據增量和類別增量兩種設置。實驗結果表明,我們提出的簡單方法 DietCL 在標簽稀疏的數據流中,其性能優于現有的監督式和半監督式持續學習算法,提升幅度為 2% 到 4%。我們還展示了該基線方法在不同數據流長度、標簽比例和計算預算條件下均表現出優異的性能。
2 相關工作
半監督持續學習(Semi-supervised Continual Learning)
隨著自監督學習在預訓練任務中的成功,越來越多的研究開始探索其在持續學習中的應用。Caccia & Pineau(2022)表明,在持續學習的元學習階段,自監督損失函數的表現優于監督損失函數。Fini 等人(2022)在離線持續學習中對每個任務進行了自監督預訓練。Gomez-Villa 等人(2022)、Pham 等人(2021)以及 Boschini 等人(2022)使用無標簽數據進行知識蒸餾或作為正則化損失,以緩解模型遺忘問題。
大多數方法在其驗證中使用的是中小型數據集,并且有適度的有標簽與無標簽數據劃分比例。這些方法通常不考慮計算開銷,并且在計算資源分配上往往對有標簽和無標簽樣本一視同仁。當我們嘗試將這些方法擴展到處理現實世界中大規模的無標簽數據和稀疏標簽時,正如我們在第3.2節所展示的那樣,由于這些方法分配給有標簽數據的計算資源有限,導致它們難以學習到有意義的與標簽相關的信息。此外,由于整體計算資源受限,這些方法也難以從無標簽數據中有效學習。
持續學習中的場景設定與預算約束
傳統的持續學習主要關注任務增量學習、類別增量學習和領域增量學習。由于現實世界中數據流的多樣性,已有大量研究致力于消除持續學習中的一些特定限制條件(Wang 等,2024),從而使其更具實用性。
一些研究考慮放寬任務邊界限制,探索了任務邊界未知(Aljundi 等,2019)或任務邊界模糊(Koh 等,2021)的情況。另一些研究則考慮數據是實時到達的,任務以單次通過的數據流形式呈現(Chaudhry 等,2019a)。這些研究大多從數據到達的形式出發來設計算法。
然而,在本文中我們關注的是由處理主體與數據流之間關系 所帶來的約束。直到最近,才有人提出關于持續學習中訓練時間限制的問題。研究表明,如果每個任務可以無限訓練,非持續學習的算法也能達到與持續學習算法相當的效果(Prabhu 等,2020;Ghunaim 等,2023)。盡管某些在線持續學習方法(如 Koh 等,2021)報告了基于固定更新次數的性能表現,但這主要是為了公平比較,并未深入探討訓練預算對算法的影響。
近期的一項研究(Prabhu 等,2023a)展示了在有限預算下離線持續學習的有效性,結果表明當預算不足時,從均衡分布中學習是有幫助的。受此啟發,我們提出應在半監督持續學習 中引入訓練時間上的約束。然而,在我們的工作中,預算限制設置還結合了數據流中標簽稀疏的設定,這構成了一個新穎但更具挑戰性、更貼近現實的問題。
3 節食型持續學習(Continual Learning on a Diet)
3.1 問題定義
在受限預算下的半監督持續學習 設置中,我們的目標是學習一個由參數 θ 參數化的函數 fθ : X → Y,該函數將圖像 x ∈ X 映射到類別標簽 y ∈ Y。
在每一個時間步 t,數據流會采樣 nt 張圖像 {xti}nti=1 ~ Xt,然后僅向 fθ 提供其中 ntl 個樣本的標簽。與以往研究不同的是,在這種設定下,持續學習算法需要在每個時間步內受限的計算預算 條件下更新參數 θ,使得 fθ 在所有已見過的數據分布上都能表現良好。
在本文中,我們將計算預算 定義為以“前向-反向傳播次數”為單位進行歸一化的總 FLOPs(浮點運算次數)。也就是說,它對應于給定批次大小下的訓練迭代次數。該計算預算涵蓋了訓練所需的所有運算操作,包括從前向-反向傳播更新模型參數,到其他任何操作(例如 Aljundi 等人(2018)中使用的樣本重要性權重等)。
我們遵循 Prabhu 等人(2023a)在受限預算持續學習中的存儲假設:緩沖區足夠大以存儲所有的有標簽數據,但每次使用數據的數量要根據計算預算加以限制。
3.2 改進機會
大多數現有的監督式和半監督式持續學習方法都假設擁有充足的計算資源和標簽數據。然而,在某些現實場景中,例如本文提出的“節食型”(diet)場景中,這些假設并不總是成立。在本節中,我們通過在 20 分割的 ImageNet10K 數據集上進行實驗,研究了現有監督方法 ER 和半監督方法 CaSSLe 在不同時間步計算預算下的表現,如圖 2 所示。我們將它們的表現與我們在第 4 節中提出的算法 DietCL 進行對比。圖中的每一個點表示:給定相應的時間步預算下,在持續學習流結束時的平均準確率(按照 Chaudhry 等人(2019a)的方式計算)。
監督式 CL 在低預算下的表現如何?
如圖 2 左側區域 (a) 所示,當每個時間步的計算預算低于 400 時,ER 的性能隨著預算減少而顯著下降。這在一定程度上是由于 ER 中存在“穩定性差距”(stability gap)(Lange 等,2023):在學習新任務的過程中,模型首先過擬合于新數據,然后才通過重放舊任務的數據來恢復知識。關于這一現象的更多細節,請參見附錄 C.1。在低預算場景下,下一個時間步的訓練可能在前一個知識恢復過程完成之前就開始了。此外,有限的可用標簽數量也可能導致模型僅捕捉到當前任務的一個狹窄分布。我們的半監督持續學習方法可以通過使用無標簽數據作為正則化項,有效消除過擬合和穩定性差距問題,從而提升持續學習的表現。
監督式 CL 是否能充分利用可用的計算預算?
監督式 CL 不僅在有限預算下難以高效學習,而且在計算預算較大時還會出現嚴重的過擬合問題,如圖 2 中區域 (b) 所示(當預算超過 400 時)。然而,在半監督方法中,這種過擬合現象并不那么嚴重。這激勵我們應將冗余的計算資源有效地分配給無標簽數據。
無標簽數據是否必要?
我們證明了即使計算預算尚未達到監督學習算法的最大需求(例如在圖 2 左側區域 (a),當預算小于 400 時),我們也可以通過利用無標簽數據來改善泛化能力,從而獲得更好的結果。此外,當預算遠遠不足時(例如圖 2 左側區域 (b),當預算大于 400 時),純監督式的持續學習算法可能無法充分利用計算預算。在這種情況下,將多余的預算用于無標簽數據有助于更好地捕捉當前任務的數據分布。
CaSSLe 面臨的挑戰
從無標簽數據中學習可能是計算密集型的。以 CaSSLe 為例,該方法最初設計為每個任務訓練 500 個 epoch。我們的實驗表明,在大規模數據集上,當計算預算受限時,CaSSLe 的表現會急劇下降。如圖 2 右側所示,即使我們將預算增加到 2500 步(約 3 個 epoch),CaSSLe 的準確率依然非常低。相比之下,DietCL 僅需大約 500 步即可收斂。
4 提出的解決方案:DietCL
我們現在介紹我們的方法,該方法聯合利用有標簽數據和無標簽數據,以高效使用計算預算并捕捉不斷變化的數據分布。在整個訓練過程中,每個任務的學習都受限于總預算 B = Bu + Bl ,其中 Bu 是用于無標簽數據的預算,Bl 是用于有標簽數據的預算。
學習分布內部關系(Learn In-Distribution Relationship)
在大規模數據流中,每一個時間步對應一個不同的數據分布,這就需要付出大量努力去學習這些分布內的特征關系。為了實現這一點,我們通過自監督學習(SSL)來利用當前分布中的無標簽數據。接下來我們將討論如何在持續學習中整合 SSL 以及應選擇哪種類型的 SSL 方法。
如圖 2(右)所示,在持續學習中進行 SSL 預訓練可能是非常耗費計算資源的。進一步在附錄 C.2 中分析后我們發現,僅通過無監督損失學習到的特征演化非常緩慢,并且不面向標簽方向。為了解決這個問題,我們對當前的無標簽數據和有標簽數據進行聯合學習,其中密集的無標簽數據作為正則化項,防止模型過擬合于稀疏的有標簽數據。
在具體的 SSL 算法選擇上,對比學習(contrastive learning)和掩碼建模(masked modeling)是兩種主流方法。對比學習通常需要將輸入圖像增強為兩個視圖,并為兩個不同的主干網絡更新梯度。因此,在給定 Bu 的無標簽數據預算下,只能進行 Bu/2 次梯度更新,導致預算嚴重未充分利用。
因此,我們采用了一種高效的掩碼建模方法——MAE(He 等,2022),通過重建輸入圖像的掩碼塊來捕捉當前分布。我們在編碼器上添加了一個重建頭,記作 fθr : Z → X ,它將編碼后的特征映射回圖像空間。在每一個時間步 t,我們使用無標簽數據流樣本來計算重建損失,公式如下:
其中,算子 ψ? 從每個無標簽樣本 x? 的掩碼塊集合 I? 中提取第 p 個圖像塊。
此外,正如 Caccia 等人(2022)所提出的那樣,持續學習可以從單獨學習當前分布中受益。因此,我們對當前時間步中未出現的類別對應的 logits 進行掩碼處理,并在當前分布的有標簽數據上計算以下損失函數:
其中,I? 表示一個掩碼函數,它將當前時間步 t 中未引入的類別所對應的索引全部置零,CE 表示交叉熵損失(cross entropy loss)。
為了緩解遺忘問題,我們進一步維護了一個任務均衡的緩存區 M ,該緩存區僅包含當前及之前時間步的有標簽數據。該緩存區上的損失函數可以表示如下:
預算分配(Budget allocation)
在我們最終的損失函數(公式4)中,模型是聯合地在當前時間步的無標簽數據(Lr)、當前時間步的有標簽數據(Lm),以及來自均衡緩存區的有標簽數據(Lb)上進行訓練的。
如圖2所示,經典方法隨著預算增加會經歷兩個階段:預算較低時的學習階段(區域a),以及預算較高時的過擬合階段(區域b)。第二階段出現過擬合的主要原因是過度使用了當前任務中非常稀疏的有標簽數據。因此,我們根據總預算,為每種數據來源分配不同的訓練預算。
當總預算低于某個閾值 B 時,我們將預算平均分配給有標簽數據(Lm)、無標簽數據(Lr)和緩存區數據(Lb)。在這個預算范圍內,我們的算法可以快速收斂,延長訓練時間對當前類別的性能提升不明顯。
因此,我們將額外的預算僅用于緩存區數據,并使用公式3作為損失函數。這一階段主要用于平衡當前類別與之前類別的學習;詳細分析請參見附錄 C.3。
在實際應用中,我們通過交叉驗證選擇閾值 B ;更多細節請參見附錄 B.3。整體訓練流程和語義代碼請參考附錄 A。
5 實驗
在本節中,我們在多個大規模數據集上進行了實驗,使用我們提出的“節食型持續學習”(CL on Diet)設置。在該設置下,數據流僅部分帶有標簽,且算法在每個時間步被授予有限的計算預算。我們首先介紹實驗設置,然后將我們提出的 DietCL 與其他多種方法進行比較。最后,我們展示了所提出方法在不同標簽率、計算預算和數據流長度下的魯棒性。
5.1 實驗設置
基準與評估指標
我們采用了三個大規模持續學習數據集:ImageNet10K、CLOC 和 CGLM,并按照 Prabhu 等人(2023a;b)的方式對 DietCL 及其他方法的性能進行評估。我們將 ImageNet10K 劃分為 20 個類別增量任務,而 CLOC 和 CGLM 則根據每張圖像元信息中的上傳時間劃分為 20 個任務。所有數據集的詳細統計信息請參見附錄 B.1。
我們報告最后一個時間步的準確率以及整個數據流上的平均準確率。具體來說,設:
訓練設置
在所有實驗中,我們使用 ViT 模型(Dosovitskiy 等,2020)用于分類任務,使用 MAE 解碼器(He 等,2022)用于重建任務。兩個模型都在 ImageNet1K 數據集上進行了預訓練,并由 He 等人(2022)發布。我們將基礎學習率設置為 10?4,基礎批次大小為 256,并根據有效批次大小線性縮放學習率。我們使用多塊 NVIDIA A100 GPU 進行訓練,每塊設備上的批次大小為 256。當累計批次大小達到 1024 時,我們執行一次損失累積和反向傳播步驟。所有其他學習參數均來自 He 等人(2022)的研究。
基線方法
我們在設定中與監督式和半監督式的持續學習方法進行了比較。監督式方法包括:
- ER (Chaudhry 等,2019b;Prabhu 等,2023a)
- EWC (Kirkpatrick 等,2017)
- MAS (Aljundi 等,2018)
- GDumb (Prabhu 等,2020)
- L2P Wang 等,2022)
半監督方法包括:
- CaSSLe (Fini 等,2022)
- DualNet (Pham 等,2021)
- CCIC (Boschini 等,2022)
所有這些方法的實現細節請參見附錄 B.2。我們對 ER 使用了均衡采樣策略,有關采樣的討論請參見附錄 B.2。我們還將我們的方法與“聯合訓練”(joint training)進行了對比,即在整個數據集上進行預訓練和微調,用以展示性能上限,參考 Hu 等人(2021)。所有基線方法的每個時間步計算預算相同。
5.2 主要結果
我們在以下設置下進行了主要對比實驗:
- ImageNet10K(20 分割) :標簽率為 1%,計算步數為 500 步;
- CLOC(20 分割) :標簽率為 0.5%,計算步數為 1000 步;
- CGLM(20 分割) :標簽率為 5%,計算步數為 600 步。
圖 3 和表 1 展示了 DietCL 及第 5.1 節中描述的基線方法的實驗結果。圖中展示了每個時間步的評估準確率 A(t),表格中則列出了最后一個時間步的準確率 A(T) 以及平均性能 A。
DietCL 與監督式 CL 的比較
圖 3 的第一行展示了 DietCL 與監督式持續學習方法之間的比較。DietCL 在所有數據集(ImageNet10K、CLOC 和 CGLM)上始終優于我們所比較的所有監督式方法。
在這些監督式方法中,表現最接近的是 Replay(ER),它通過聯合使用當前有標簽數據和從記憶庫中均勻采樣的歷史標簽數據進行訓練。然而,DietCL 中對無標簽數據的引入顯著提升了性能,在各數據集上的準確率分別達到了 16.82%、5.98% 和 24.34%(見表 1)。
這表明,即使在每一步計算資源受限的情況下,DietCL 仍能有效利用無標簽數據來提升數據流中的學習效果。
DietCL 與半監督 CL 的比較
圖 3 的第二行展示了 DietCL 與近期提出的半監督持續學習方法以及 SSL 上限的對比。這些半監督方法最初并未考慮計算預算限制,在固定計算預算下其性能大幅下降。
如表 1 所示,其中表現最好的方法 CaSSLe 在 ImageNet10K 數據集最后一個任務上的平均準確率僅為 5.78%,而我們的方法達到了 16.82%。CaSSLe 最初設計為在其基準測試中對無標簽數據進行 500 輪訓練,這在我們的設定下大約相當于 50,000 步,而我們僅提供了 500 步的有限預算,導致其性能急劇下降。
此外,在我們的大規模數據集上評估這些半監督基線時,我們發現一些依賴類別間關系的方法(如 DualNet 和 CCIC)難以應對如此多類別的挑戰。
同時,如圖 3 所示,我們在三個數據集上的 fine-tuning 表現比其他半監督方法更接近“聯合訓練”(joint training)的性能上限。
學習趨勢分析
在 ImageNet10K 的類別增量基準中(約含 10,000 個類別),模型需要在每個時間步學習約 500 個新類別。這對大多數方法來說是一個巨大挑戰——既要學習大量新類別,又要記住之前學到的海量舊類別。
而在 CGLM 和 CLOC 的時間增量基準中,每個類別的數據量和內容會隨時間變化。這要求模型必須學習類別的本質特征,而不是過度擬合最近的訓練分布,以避免在識別先前任務時產生推理偏差。
各種方法的整體學習趨勢可以從圖 3 中曲線的斜率和表 1 中的 A 值看出。值得注意的是,圖中某些方法(如 EWC、MAS 和 CaSSLe)在前幾個任務中就出現了顯著的準確率下降,而 DietCL 避免了這種大幅性能下滑,并在 CLOC 和 CGLM 數據集上逐步提升了準確率。
在表中,我們的方法在所有三個數據集上仍然保持最高的平均性能 A,分別為 24.9%、4.98% 和 20.26%。我們認為這是由于我們在當前和歷史數據之間合理分配了預算,使得我們的方法能夠高效利用可用資源并適應新類別。
最后,結合我們在引言中提到的 Snapchat 示例,我們在附錄 D 中進一步將實驗結果與現實世界問題建立了聯系。
5.3 對公式 4 的消融實驗
我們在 1% 標簽率、20 分割的 ImageNet10K 基準上,設置了 500 步計算預算,對公式 4 進行了消融研究,結果如表 2 所示。
我們從一個 Replay 基線開始,逐步加入掩碼分類損失(masked classification loss)、均衡緩存區(balanced buffer)以及重建損失(reconstruction loss),觀察其各自對性能的影響:
加入掩碼分類損失后,最終準確率提升了 0.9% ;
引入均衡緩存區后,準確率進一步提升了 0.13% ;
在利用無標簽數據的幫助下,準確率又提升了 0.86% 。
這一結果表明,在孤立地學習當前分布的過程中,無論是通過掩碼損失還是利用無標簽數據,都對受限訓練預算的場景起到了最主要的作用。
我們在附錄 B.4 中還進行了其他組合順序的消融實驗。
5.4 計算預算與標簽率
我們在 ImageNet10K 數據集上研究了 DietCL 與最穩健的監督方法 GDumb 和表現最好的半監督方法 CaSSLe 在不同標簽率、計算預算和數據流長度下的穩定性。
改變計算預算 我們進行了實驗,分別設置每個時間步的計算預算為:100、500、1500 和 2500 次迭代。在以下所有實驗中,其他參數保持不變,即標簽率為 1%,共 20 個時間步。結果匯總如圖 4 所示。
當每個時間步的預算減少到 100 次迭代時,所有方法的性能都明顯下降。這一現象在監督式方法 GDumb 上尤為顯著,說明監督學習要達到可接受的性能所需的基本資源相對較多。然而值得注意的是,即使在如此嚴格的預算限制下,DietCL 仍優于另外兩種方法 ,并在各種預算條件下保持穩定的性能,包括每個時間步高達 2500 次迭代的情況。
雖然 2500 次迭代對于 DietCL 來說可能已足夠收斂,但相較于像 CaSSLe 這樣的半監督方法,它仍然是一個受限的設定。如前所述,CaSSLe 最初設計的訓練步數高達 50,000 步,是我們所測試最大步數的 20 倍。因此,期望 CaSSLe 在僅 5% 的原始預算內完成訓練顯然是不現實的。
改變標簽率 圖 5 展示了在數據流中標記樣本稀疏程度不同的情況下的實驗結果,即標簽率為 0.5%、1%、5% 和 10%。我們保持每個時間步的計算預算為 500 步,共 20 個時間步。
我們觀察到,當標簽率從 0.5% 提高到 1%,再提高到 5% 時,盡管計算預算未變,我們的方法平均準確率有明顯提升。然而,GDumb 和 CaSSLe 的性能提升并不顯著。此外,我們方法與其他方法之間的性能差距也在擴大,這表明我們在利用有標簽數據方面更加高效,這得益于無標簽數據的幫助。
當我們進一步將標簽率從 5% 提高到 10% 時(此時后續任務的數據流計算預算不足),并未觀察到明顯的性能提升。我們得出結論:當計算預算充足時,有標簽數據可以更好地指導無監督訓練,從而帶來更好的整體性能。
改變時間步數量 我們通過在 ImageNet10K 上進行實驗,研究了數據流長度(即任務數量)對模型性能的影響,實驗設置了 10、20、50 和 100 個時間步。這模擬了數據呈現速度的變化。這些實驗中我們使用的標簽率為 1%。
在不同任務數量的實驗中,我們通過按比例調整每個時間步的預算,使總的計算迭代次數與之前實驗一致(即等于 20 × 500)。結果顯示在圖 6 中。
在有相同數量的有標簽數據遍歷次數的情況下,DietCL 在整個任務序列上的平均準確率相近,分別為:24.4%、24.9%、23.7% 和 22.2%。然而,當任務長度增加時,GDumb 和 CaSSLe 的性能顯著下降。這表明,我們的方法相比以往方法更能應對小批量、高頻次的數據流場景。
6 結論
我們重新思考了現實世界中計算預算和標簽稀疏性的問題,這些問題在以往的持續學習研究中并未受到足夠重視。為了解決這一挑戰,我們設計了一種高效且有效的持續學習算法 DietCL ,該算法聯合使用有標簽和無標簽數據進行模型訓練。
我們在大規模基準數據集上對該設定進行了評估,包括 ImageNet10K、ImageNet2K 和 CGLM。我們的方法超越了經典方法,在平均準確率上達到了它們的兩倍。我們還在另外兩個具有挑戰性的基準數據集上展示了本方法的優越性能,并分析了計算預算、數據流長度以及標簽稀疏性對方法性能的影響。
我們相信,DietCL 可以作為探索有限計算預算與稀疏標簽條件下新型持續學習算法的一個起點 。
A 算法的語義代碼
整體訓練流程見算法 1。在 ModifyClassificationHead(修改分類頭) 過程中,我們根據當前已見過的類別總數擴展分類頭的最后一層,并使用之前學習到的權重來初始化已有的維度。在 SplitBudget(預算劃分) 過程中,我們將總預算劃分為聯合訓練階段和微調階段兩部分。
B 實驗更多細節
B.1 基準數據集統計信息
ImageNet10K:類別增量(class-incremental)
我們從 ImageNet21K V2 數據集(Ridnik 等,2021)中構建了一個大規模、標簽稀疏的 ImageNet10K 基準。為了消除對 ImageNet1K 的潛在偏差,尤其是在使用預訓練模型時,我們首先從 ImageNet21K V2 中移除了與 ImageNet1K(Deng 等,2009)中重復的類別。
最終得到的 ImageNet10K 基準包含 9459 個類別 ,總計 9,822,675 張有標簽圖像 。我們將該基準劃分為 T 個部分,代表一個持續學習流中的 T 個時間步。在每個劃分中,我們隨機選擇每個類別的固定比例作為獨立的有標簽數據。這種構建持續學習數據集的方式在之前的工作中已被廣泛采用(Chaudhry 等,2019a;Fini 等,2022)。
在每一個時間步中,由于我們假設數據流僅部分帶有標簽,因此我們只使用了每個時間步中 1% 的標簽數據 。
CLOC:領域增量(domain-incremental)
我們在 CLOC(Cai 等,2021)數據集上評估我們方法的領域適應能力。該基準包含 10788 個類別 ,總計約 3800 萬張用于地理定位任務的圖像 。圖像是按照拍攝時間的時間戳進行排序的,模擬了一種自然的數據分布偏移。
該數據集的數據流被劃分為 20 個時間步 ,在每個時間步中,只有 0.5% 的標簽可用 。也就是說,每個時間步大約揭示 190 萬張圖像 ,其中僅有 1000 張圖像 帶有標簽。
CGLM:領域增量(domain-incremental)
我們在 CGLM(Prabhu 等,2023a)數據集上也評估了我們方法的領域適應能力。該基準包含 10788 個類別 ,總計來自 Google 地圖的 581,100 張地標圖像 。這些圖像同樣按照拍攝時間的時間戳進行排序,模擬了自然分布偏移。
該數據集的數據流也被劃分為 20 個時間步 ,每個時間步中 5% 的標簽可用 。也就是說,每個時間步揭示了大約 3 萬張圖像 ,其中僅有 600 張圖像 帶有標簽。
在表 3 中,我們展示了 ImageNet10K 和 ImageNet2K 的統計數據,并將其與其他流行的半監督持續學習基準進行了比較。這兩個基準在規模和標簽稀疏性方面都遠遠超過之前的基準。
在表 4 中,我們展示了 CGLM 的統計數據,并將其與近期其他基于時間變化的領域增量半監督持續學習基準進行對比。CGLM 擁有更多的類別和更稀疏的標簽,因此相比 CLEAR10 和 CLEAR100 要困難得多。
GDumb(Prabhu 等,2020)我們按照原論文的標準實現進行操作。
掩碼被設置為當前時間步之前所見過的所有類別。
L2P(Wang 等,2022)
原始論文中使用了在 ImageNet21K 上預訓練、并在 ImageNet1K 上微調的模型權重。我們加載了 ImageNet1K 預訓練的權重,但發現性能突然下降。作者提出每個時間步訓練 5 個 epoch,且不使用回放緩沖區,也提供了有限緩沖區大小的情況。然而我們發現,在我們的設定下,如果修改模型使用無限大的回放緩沖區但限制梯度更新次數,性能反而更好。因此我們報告的是修改后的版本。CaSSLe(Fini 等,2022)
我們選擇了 Barlow Twins 作為 SSL 基線來報告結果。原始論文提出在 ImageNet100 上進行 400 個 epoch 的預訓練和 100 個 epoch 的微調。我們在設定中采用了 4:1 的預訓練與微調步數比例(總預算固定)。在微調階段,我們以類別增量的方式采用 CaSSLe,即為迄今為止所有已見類別訓練一個線性分類器,以實現公平比較。性能下降主要來源于類別增量分類器以及計算資源受限。DualNet(Pham 等,2021)
鑒于慢速網絡僅包含少量卷積層,我們只為快速網絡加載了 ImageNet1K 的預訓練權重。在原始論文的半監督實現中,作者采用了一種“變量采樣”方法,根據與標簽率的比較決定當前批次是使用監督損失還是無監督損失,這種方法顯著增強了有標簽數據的多樣性。在我們的實現中,我們首先將當前數據劃分為有標簽和無標簽兩個子集,然后采用原論文的變量采樣方法決定從哪個子集中采樣。我們將閾值設為 0.5,這給出了最好的結果,但仍低于原論文中的表現。此外,我們在計算中計入了所有的梯度步驟,包括對比損失所需的兩個視圖變換,同時固定總的梯度步數,這也導致了進一步的性能下降。盡管原論文提出了任務無關(task-agnostic)和無任務(task-free)的訓練策略,我們也嘗試了這兩種方式,但在評估時仍以類別增量方式進行。在這類大規模類別增量流、計算資源受限的設定下,兩種策略都沒有取得理想的結果。上界(Upper bound,He 等,2022)
特別地,SSL 聯合訓練(Joint Training)首先對所有無標簽數據進行一次自監督訓練,然后在所有可用的有標簽數據上進行微調。給定的總計算預算等于持續學習者的有效預算,即 20 個時間步 × 每個時間步的預算 。其中 80% 的預算用于自監督訓練,剩下的 20% 用于微調。
B.3 預算閾值
我們選擇通過交叉驗證來確定閾值 B ,以劃分“均衡訓練階段”和“緩存區學習階段”。交叉驗證僅在總共 20 個任務中的前 3 個任務上進行;這在持續學習領域是標準做法,并被廣泛接受(Chaudhry 等,2019a)。
為了展示這種選擇的魯棒性,我們在 ImageNet10K 基準上進行了不同監督預算的實驗,并在表 6 中比較了它們在第 2、3 和 4 個任務上的表現。我們展示了截至相應任務的平均準確率。
當從第 2、3 或第 4 個任務中選擇監督預算時,所選的預算閾值始終在 400 到 450 之間,表現出一致性。
B.4 其他順序下的消融實驗
在本節的消融研究中,我們展示了不同組合順序以及對應的觀察結果,并報告了最終模型的平均準確率。
第一張表格顯示,任務均衡緩存區(task-balanced buffer)可以對 Replay 方法帶來輕微的提升。然后,無論是 Lm(有標簽數據損失) 還是 Lr(重建損失) ,都能將算法性能提升大約 1%。這與我們最初的觀察結果一致。
第二張表格顯示,在沒有使用掩碼損失的情況下,無標簽數據只能帶來非常有限的性能提升。這進一步驗證了我們在第 4 節中提出的觀點:無標簽數據的損失必須由適當的監督損失來引導。
C 分析
C.1 在低計算預算下 ER 與 DietCL 的對比
在圖 7 中,我們展示了在總計算步數為 300 的情況下,DietCL 和 ER 在訓練第 2 個時間步時,對第 0、1 和 2 時間步引入類別的驗證損失和準確率的變化情況。
我們的結果顯示,經典持續學習算法 ER 出現了 Lange 等人(2023)所提出的“穩定性差距”(stability gap)現象。具體來說,在學習新類別過程中,之前已學類別的準確率首先會出現下降,直到新類別的學習趨于穩定后才會重新上升。
值得注意的是,在 ER 中,盡管當前類別的準確率最初可以上升到約 40%,但在后續時間步中又顯著下降,之后才逐漸回升。此外,當計算預算耗盡時,舊類別的學習仍未完成。
然而,在我們的算法中,如圖中藍色曲線所示,之前已學類別的準確率沒有出現明顯的下降,且新類別的學習過程更加平穩。換句話說,當前類別的準確率不會像 ER 那樣經歷“先升后降再升”的曲折過程,從而節省了大量的計算預算。
我們推測,這是因為在 ER 中,有標簽數據稀疏,模型更容易過擬合于當前類別稀疏數據所代表的空間。相反,我們在算法中引入了無標簽數據,引導學習過程朝向一個更具泛化能力的空間,從而消除了過擬合問題和穩定性差距。
因此,我們的算法在低預算場景下具有很高的學習效率,尤其是在標簽稀疏的情況下表現尤為突出。
C.2 在持續學習中直接進行預訓練與微調
我們對在持續學習中使用自監督學習(SSL)方法的有效性進行了實證研究。
一種簡單的方法是迭代地進行兩階段訓練,包括預訓練階段和微調階段,在預訓練階段利用無標簽數據。為了建立基線,我們實現了 OneStage 和 TwoStage 方法,這兩種方法都會回放當前時間步的所有有標簽數據。
在 OneStage 方法 中,我們從所有已見過的類別中隨機采樣樣本批次,并使用交叉熵損失進行分類。
在 TwoStage 方法 中,我們首先使用當前時間步的無標簽數據進行 MAE 預訓練,計算預算按照固定比例分配;然后使用剩余預算,從緩存區中隨機采樣有標簽數據對模型進行微調。
我們在一個標簽極其稀疏 但計算預算相對充足 的設定下進行了該實驗。請注意,與監督式訓練場景相比,這種設定已經非常有利于半監督訓練,因為預算較大,而有標簽數據極度稀疏。
C.3 微調階段的影響
在圖 2 中,我們利用最初的 400 個計算步 ,使用有標簽數據、無標簽數據以及來自均衡緩存區的數據對模型進行聯合訓練。對于超過 400 步的預算,剩余的計算步則被用于僅使用緩存區數據進行微調。
選擇 400 步作為分界點是基于觀察結果:傳統方法在超過這一閾值后往往會犧牲性能,這表明 400 步已足夠學習當前任務的數據分布。
以一個實驗為例,在 20 分割 ImageNet10K 數據集上,標簽率為 0.01、計算預算為 600 步的設定下,我們在圖 9 中比較了有無微調 的訓練結果。圖中顯示,在訓練的最后階段應用均衡微調(balanced fine-tuning)可以帶來更加均衡、也因此更優的性能表現。
D 與現實問題的聯系
我們旨在研究這樣一種場景:無法對所有新到達的數據進行完整的訓練輪次(full epoch) ,但可以為稀疏標注的數據分配足夠的計算資源。我們的算法討論了如何利用來自有標簽數據的多余預算,并將其用于無標簽數據的學習。
在我們的實驗中,在一塊 80G 的 NVIDIA A100 GPU 上,對 1% 標簽率、20 分割的 ImageNet10K 數據流 中的每個任務進行訓練至少需要 0.5 個 GPU 小時 。每個任務總共包含 100 萬張圖像 和 1 萬張有標簽圖像 ,我們以 1024 的批次大小 進行了 500 步梯度更新 。
現在,我們將本文中使用的計算需求(如上所述)映射到 Snapchat 示例 的規模上。如引言中所述,Snap 每天接收約 35 億段視頻 。假設我們通過分類每段視頻的 N 張截圖來進行排序任務,那么每天將產生 35 × N 億張圖像數據 。
我們假設該任務配備了 L 名標注員 ,每人每天可生成 4000 個標簽 。為了匹配我們論文中的 500 梯度步訓練預算,我們預計所需的總計算預算應至少為:
B = (4000L / 10000) × 500 = 200L
由于我們在實驗中完成 500 步訓練需要 0.5 個 GPU 小時,因此 200L 步相當于約 100L 個 GPU 小時 。
也就是說,當 B/L(即每個標注人員對應的計算預算)約為 100 個 GPU 小時 (相當于使用 4 塊 GPU 運行將近一整天),我們的算法可以在預定義的排序任務中優于其他基線方法。
此時,性能的上限與標簽率相關,其值為:
4000L / (3.5 × 10?N)
我們注意到,所需時間與 GPU 類型 、數據加載時間 、GPU 通信時間 等多種因素密切相關。例如,在一個數據加載較慢的系統中,可能需要多達 18 個 GPU 小時 。通過與加速領域專家的合作,B/L(算法收斂所需的人均標注預算) 可以進一步降低。
原文鏈接:https://arxiv.org/pdf/2404.12766
特別聲明:以上內容(如有圖片或視頻亦包括在內)為自媒體平臺“網易號”用戶上傳并發布,本平臺僅提供信息存儲服務。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.