選單

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

機器之心專欄

作者:蘇永怡

華南理工、A*STAR 團隊和鵬城實驗室聯合提出了針對測試階段訓練(TTT)問題的系統性分類準則。

域適應是解決遷移學習的重要方法,當前域適應當法依賴原域和目標域資料進行同步訓練。當源域資料不可得,同時目標域資料不完全可見時,測試階段訓練(Test- Time Training)成為新的域適應方法。當前針對 Test-Time Training(TTT)的研究廣泛利用了自監督學習、對比學習、自訓練等方法,然而,如何定義真實環境下的 TTT 卻被經常忽略,以至於不同方法間缺乏可比性。

近日,華南理工、A*STAR 團隊和鵬城實驗室聯合提出了針對 TTT 問題的系統性分類準則,透過區分方法是否具備順序推理能力(Sequential Inference)和是否需要修改源域訓練目標,對當前方法做了詳細分類。同時,提出了基於目標域資料定錨聚類(Anchored Clustering)的方法,在多種 TTT 分類下取得了最高的分類準確率,本文對 TTT 的後續研究指明瞭正確的方向,避免了實驗設定混淆帶來的結果不可比問題。研究論文已被 NeurIPS 2022 接收。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

論文:https://arxiv。org/abs/2206。02721

程式碼:https://github。com/Gorilla-Lab-SCUT/TTAC

一、引言

深度學習的成功主要歸功於大量的標註資料和訓練集與測試集獨立同分布的假設。在一般情況下,需要在合成數據上訓練,然後在真實資料上測試時,以上假設就沒辦法滿足,這也被稱為域偏移。為了緩解這個問題,域適應 (Domain Adaptation, DA) 誕生了。現有的 DA 工作要麼需要在訓練期間訪問源域和目標域的資料,要麼同時在多個域進行訓練。前者需要模型在做適應 (Adaptation) 訓練期間總是能訪問到源域資料,而後者需要更加昂貴的計算量。為了降低對源域資料的依賴,由於隱私問題或者儲存開銷不能訪問源域資料,無需源域資料的域適應 (Source-Free Domain Adaptation, SFDA) 解決無法訪問源域資料的域適應問題。作者發現 SFDA 需要在整個目標資料集上訓練多個輪次才能達到收斂,在面對流式資料需要及時做出推斷預測的時候 SFDA 無法解決此類問題。這種面對流式資料需要及時適應並做出推斷預測的更現實的設定,被稱為測試時訓練 (Test-Time Training, TTT) 或測試時適應(Test-Time Adaptation, TTA)。

作者注意到在社群裡對 TTT 的定義存在混亂從而導致比較的不公平。論文以兩個關鍵的因素對現有的 TTT 方法進行分類:

對於資料是流式出現的並需要對當前出現的資料作出及時預測的,稱之為單輪適應協議(One-Pass Adaptation);對於其他不符合以上設定的稱為多輪適應協議(Multi-Pass Adaptation),模型可能需要在整個測試集上進行多輪次的更新後,再進行從頭到尾的推斷預測。

根據是否需要修改源域的訓練損失方程,比如引入額外的自監督分支以達到更有效的 TTT。

這篇論文的目標是解決最現實和最具挑戰性的 TTT 協議,即單輪適應並無需修改訓練損失方程。這個設定類似於 TENT[1]提出的 TTA,但不限於使用來自源域的輕量級資訊,如特徵的統計量。鑑於 TTT 在測試時高效適應的目標,該假設在計算上是高效的,並大大提高了 TTT 的效能。作者將這個新的 TTT 協議命名為順序測試時訓練(sequential Test Time Training, sTTT)。

除了以上對不同 TTT 方法的分類外,論文還提出了兩個技術讓 sTTT 更加有效和準確:

論文提出了測試時錨定聚類 (Test-Time Anchored Clustering, TTAC) 方法。

為了降低錯誤偽標籤對聚類更新的影響,論文根據網路對樣本的預測穩定性和自信度對偽標籤進行過濾。

二、方法介紹

論文分了四部分來闡述所提出的方法,分別是 1)介紹測試時訓練 (TTT) 的錨定聚類模組,如圖 1 中的 Anchored Clustering 部分;2)介紹用於過濾偽標籤的一些策略,如圖 1 中的 Pseudo Label Filter 部分;3)不同於 TTT++[2]中的使用 L2 距離來衡量兩個分佈的距離,作者使用了 KL 散度來度量兩個全域性特徵分佈間的距離;4)介紹在測試時訓練 (TTT) 過程的特徵統計量的有效更新迭代方法。最後第五小節給出了整個演算法的過程程式碼。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

第一部分 在錨定聚類裡,作者首先使用混合高斯對目標域的特徵進行建模,其中每個高斯分量代表一個被發現的聚類。然後,作者使用源域中每個類別的分佈作為目標域分佈的錨點來進行匹配。透過這種方式,測試資料特徵可以同時形成叢集,並且叢集與源域類別相關聯,從而達到了對目標域的推廣。概述來說就是,將源域和目標域的特徵分別根據類別資訊建模成:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

然後透過 KL 散度度量兩個混合高斯分佈的距離,並透過減少 KL 散度來達到兩個域特徵的匹配。可是,在兩個混合高斯分佈上直接求解 KL 散度並沒有閉式解,這導致了無法使用有效的梯度最佳化方法。在這篇論文中,作者在源域和目標域中分配相同數量的叢集,每個目標域叢集被分配給一個源域叢集,這樣就可以將整個混合高斯的 KL 散度求解變成了各對高斯之間的 KL 散度之和。如下式:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

上式的閉式解形式為:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

在公式 2 中,源域叢集的引數可以線下收集完,而且由於只用到了輕量化統計資料,所以不會導致隱私洩漏問題且只使用了少量的計算和儲存開銷。對於目標域的變數,涉及到了偽標籤的使用,作者為此設計了一套有效的且輕量的偽標籤過濾策略。

第二部分 偽標籤過濾的策略主要分為兩部分:

1)時序上一致性預測的過濾:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

2)根據後驗機率的過濾:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

最後,使用過濾後的樣本來求解目標域叢集的統計量:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

第三部分 由於在錨定聚類中,部分被濾除的樣本並沒有參與目標域的估計。作者還對所有測試樣本進行全域性特徵對齊,類似錨定聚類中對叢集的做法,這裡將所有樣本看作一個整體的叢集,在源域和目標域分別定義

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

然後再次以最小化 KL 散度為目標對齊全域性特徵分佈:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

第四部分 以上三部分都在介紹一些域對齊的手段,但在 TTT 過程中,想要估計一個目標域的分佈是不簡單的,因為我們無法觀測整個目標域的資料。在前沿的工作中,TTT++[2]使用了一個特徵佇列來儲存過去的部分樣本,來計算一個區域性分佈來估計整體分佈。但這樣不但帶來了記憶體開銷還導致了精度與記憶體之間的 trade off。在這篇論文中,作者提出了迭代更新統計量的方式來緩解記憶體開銷。具體的迭代更新式子如下:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

總的來說,整個演算法如下演算法 1 所示:

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

三、實驗結果

正如引言部分所說,這篇論文中作者非常注重不同 TTT 策略下的不同方法的公平比較。作者將所有 TTT 方法根據以下兩個關鍵因素來分類:1)是否單輪適應協議 (One-Pass Adaptation) 和 2)修改源域的訓練損失方程,分別記為 Y/N 表示需要或不需要修改源域訓練方程,O/M 表示單輪適應或多輪適應。除此之外,作者在 6 個基準的資料集上進行了充分的對比實驗和一些進一步的分析。

如表一所示,TTT++[2]同時出現在了 N-O 和 Y-O 的協議下,是因為 TTT++[2]擁有一個額外的自監督分支,我們在 N-O 協議下將不新增自監督分支的損失,而在 Y-O 下可以正常使用此分子的損失。TTAC 在 Y-O 下也是使用了跟 TTT++[2]一樣的自監督分支。從表中可以看到,在所有的 TTT 協議下所有資料集下,TTAC 均取得到最優的結果;在 CIFAR10-C 和 CIFAR100-C 資料集上,TTAC 都取得了 3% 以上的提升。從表 2 - 表 5 分別是 ImageNet-C、CIFAR10。1、VisDA 上的資料,TTAC 均取到了最優的結果。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

此外,作者在多個 TTT 協議下同時做了嚴格的消融實驗,清晰地看出了每個部件的作用,如表 6 所示。首先從 L2 Dist 和 KLD 的對比中,可以看出使用 KL 散度來衡量兩個分佈具有更優的效果;其次,發現如果單單使用 Anchored Clustering 或單獨使用偽標籤監督提升只有 14%,但如果結合了 Anchored Cluster 和 Pseudo Label Filter 就可以看到效能顯著提高 29。15% -> 11。33%。這也可以看出每個部件的必要性和有效的結合。

NeurIPS 2022 | 如何定義測試階段訓練?順序推理和域適應聚類方法

最後,作者在正文的尾部從五個維度對 TTAC 展開了充分的分析,分別是 sTTT (N-O)下的累計表現、TTAC 特徵的 TSNE 視覺化、源域無關的 TTT 分析、測試樣本佇列和更新輪次的分析、以 wall-clock 時間度量計算開銷。還有更多有趣的證明和分析會展示在文章的附錄中。

四、總結

本文只是粗糙地介紹了 TTAC 這篇工作的貢獻點:對已有 TTT 方法的分類比較、提出的方法、以及各個 TTT 協議分類下的實驗。論文和附錄中會有更加詳細的討論和分析。我們希望這項工作能夠為 TTT 方法提供一個公平的基準,未來的研究應該在各自的協議內進行比較。

[1] Dequan Wang, Evan Shelhamer, Shaoteng Liu, Bruno Olshausen, and Trevor Darrell。 Tent: Fully test-time adaptation by entropy minimization。 In International Conference on Learning Representations, 2021。

[2] Yuejiang Liu, Parth Kothari, Bastienvan Delft, Baptiste Bellot-Gurlet, Taylor Mordan, and Alexandre Alahi。 Ttt++: When does self-supervised test-time training fail or thrive? In Advances in Neural Information Processing Systems, 2021。

開啟App看更多精彩內容