热门问题
时间线
聊天
视角

神經正切核

来自维基百科,自由的百科全书

Remove ads

神經正切核(英語:neural tangent kernel,簡稱NTK)是一種核函數,用於描述深度人工神經網絡梯度下降訓練過程中的動態演變。NTK將核方法的理論工具引入人工神經網絡的研究之中。

通常而言,核函數是一類正半定的對稱函數,衡量兩個輸入之間的某種相似性。而NTK則是從一個具體的神經網絡中推導而來的特殊核函數。一般情況下,當神經網絡的參數在訓練過程中更新時,NTK也會隨之變化。然而,當神經網絡的層寬趨於無窮大時,NTK會收斂成為一個固定不變的函數。這揭示了寬神經網絡的訓練與核方法之間的對偶性:無限寬度極限下的梯度下降完全等價於使用NTK的核梯度下降。這意味著,通過梯度下降來最小化神經網絡的最小二乘損失,與使用NTK的無核回歸所得到的均值估計是一致的。這種對偶性使得我們能夠用簡單的閉式方程來描述寬神經網絡的訓練動態、泛化能力和預測結果。

NTK最早由Arthur Jacot、Franck Gabriel和Clément Hongler於2018年提出[1],他們利用NTK研究了全連接神經網絡的收斂和泛化特性。隨後的一些研究[2][3]則將NTK的分析框架推廣到了其他類型的神經網絡架構中。實際上,NTK所揭示的現象並非神經網絡所特有,在更一般的非線性模型中通用適當的縮放也能觀察到類似的行為。[4]

Remove ads

主要結論

表示一個通過給定神經網絡計算的純量函數,其中是輸入,是網絡中的所有參數,則相應的神經正切核可以定義為[1]

由於NTK可以表示為經特徵映射後兩個輸入之間的點積(此處神經網絡函數的梯度充當了將輸入映射到高維空間的特徵映射函數),這保證了NTK具有對稱性半正定性。因此,NTK是一個有效的核函數。

現在我們考慮一個全連接神經網絡,其參數從任意均值為零的分布中進行獨立同分布採樣。這種對函數的隨機初始化,會相應地使形成一個特定的函數分布。我們的目標便是分析該分布在初始化及整個梯度下降訓練過程中的統計特性。為了直觀地展現該分布,我們可以構建一個神經網絡的集成(ensemble),即從的初始分布中多次採樣,並按照完全相同的流程來訓練每一個神經網絡實例。

Thumb
初始化時,寬神經網絡的集成是一個零均值的高斯過程。在基於均方誤差梯度下降訓練期間,該集成則會根據NTK進行演化。收斂後的集成也是一個高斯過程,其分布的均值為無脊核回歸的解,方差則在所有訓練數據點上降為零。此處的神經網絡是一個純量函數,其訓練數據從單位圓上採樣得到。

神經網絡中每層的神經元數量稱為該層的寬度。假設我們將每個隱藏層的寬度都設為無窮大,並使用梯度下降法(並配合一個足夠小的學習率)來訓練這個神經網絡。在此無限寬度極限之下,神經網絡會出現一些非常理想的特性:

  • 在訓練開始前的初始化階段,神經網絡集成是一個零均值高斯過程(GP)。[5]這意味著函數分布是一種最大熵分布,其均值為,協方差為,其中由神經網絡的架構所決定。換言之,神經網絡函數的分布在初始化時除了一階矩(均值)和二階矩(協方差)之外沒有其他結構。這一結論遵循中心極限定理
  • NTK具有確定性[1][6],即其獨立於參數的隨機初始化。
  • NTK在整個訓練過程中保持不變。[1][6]
  • 儘管每個參數在訓練過程中的變化都微乎其微,但它們的共同作用卻足以對網絡的最終輸出產生顯著的改變,從而實現有效的學習。[6]
  • 正由於在訓練過程中每個參數的變化都可以忽略不計,神經網絡可以被線性化,即可以通過關於初始參數的一階泰勒展開式來近似:[6](但神經網絡對於輸入仍保持非線性的特徵)
  • 神經網絡的訓練動態等價於使用NTK作為核函數的核梯度下降[1]如果損失函數均方誤差,那麼的最終分布仍然是高斯過程,但均值和協方差則會發生變化。[1][6]具體而言,其均值會收斂到以NTK為核的無核回歸所給出的解,而協方差可以用NTK和初始協方差來表示。同時可以證明,集成的方差在所有訓練數據點處都會降為零。換句話說,無論如何隨機初始化,神經網絡最終總能完美地擬合所有訓練數據。

從物理學的角度來看,NTK可以被視為一種哈密頓量,因為當神經網絡通過無窮小的步長(即連續時間極限)進行梯度下降訓練時,它描述了網絡中可觀測量隨時間演變的過程。[7]

Remove ads

應用

無脊核回歸和核梯度下降

核方法是一類機器學習算法,其特性是僅使用輸入點之間的成對關係進行計算,而不依賴於輸入的具體值。所有這些成對關係都可以完全由一個核函數所描述。核函數是一個對稱、半正定的函數,它接收兩個輸入,並返回一個表示它們之間某種相似度的值。一個與此完全等效的定義是,存在某種特徵映射,使得核函數可以表示為經映射後輸入值的點積

核方法的性質取決於核函數的選擇。我們以經典的線性回歸為例來說明。假設有個訓練樣本,它們都由一個線性函數生成,而我們的任務則是通過這些樣本來學習權重,使其儘可能地接近真實權重。為實現這一目標,我們可以最小化模型預測值與真實訓練值之間的均方誤差

對於上述的最小化問題,存在顯式解:。其中,是由所有訓練輸入作為列向量構成的矩陣,而則是所有訓練輸出構成的向量。求得後,便可以對任意新的輸入做出預測:

我們可以將上述結果改寫為[8]在這種形式下,問題的解可以完全通過輸入之間的點積來表達。這一發現意味著我們得以將線性回歸進行推廣。此時我們不再直接計算輸入之間的點積,而是先用一個給定的特徵映射對輸入進行變換,然後再計算變換後的點積。如前文所述,這可以通過一個核函數來實現。於是,我們能夠得到無脊核回歸的預測公式:

如果核矩陣奇異陣,可使用穆爾-彭羅斯廣義逆來代替逆矩陣。該回歸方程被稱為「無脊」回歸,是因為其公式中缺少了脊正則化項。

從這個角度來看,線性回歸可以看作是核回歸的一個特例,它對應於使用恆等特徵映射的核回歸。反過來,核回歸也可看作是在特徵空間中的線性回歸。但其通常在輸入空間中是非線性的,這也是核算法的核心優勢所在。

正如我們可以使用梯度下降這類等迭代優化算法來來解線性回歸問題,我們同樣也可以使用核梯度下降來求解核回歸問題。這等價於在特徵空間中使用標準的梯度下降。對於線性回歸而言,如果權重向量的初始值接近於零,那麼最小二乘梯度下降會收斂到最小範數解,即在所有能夠完美擬合訓練數據的解中歐幾里得範數最小的解。類似地,核梯度下降會得到再生核希爾伯特空間英語Reproducing kernel Hilbert space範數最小的解。這一現象被則為梯度下降的隱式正則化。

NTK在無限寬神經網絡與核方法之間建立了嚴格的對應關係:當使用最小二乘損失訓練時,無限寬神經網絡給出的預測,其期望與以NTK作為核函數的無脊核回歸所得到的預測是一致的。這表明,對於採用NTK參數化的大規模神經網絡,其性能可以通過選取合適核函數的核方法加以復現。[1][2]

Remove ads

過參數化、插值和泛化

在過參數化(overparametrization)的模型中,可調參數的數量大於訓練樣本的數量。此種情況下,模型能夠記憶(完美擬合)所有的訓練數據。因此,過參數化的模型會對訓練數據進行插值,使其在訓練集上基本實現零誤差。[9]

Thumb
現代過參數化的模型儘管具有對訓練集進行插值(記憶)的能力,但仍能實現了較低的泛化誤差。[9]通過研究高維核回歸的泛化特性可以解釋這一現象。

核回歸通常被視為一種非參數機器學習算法,因為一旦選定了核函數,模型就不再有需要學習的顯式參數了。另外核回歸還可以看作是特徵空間中的線性回歸,因此「有效」的參數數量相當於特徵空間的維數。因此,研究具有高維特徵映射的核方法,可以讓我們對嚴重過參數化的模型有更深的了解。

我們以泛化問題為例。根據經典統計學,記憶會導致模型擬合訓練數據中的噪聲信號,從而損害其在未見數據上的預測能力。為了避免這種情況,傳統機器學習算法通常會引入正則化來抑制這種過擬合噪聲的傾向。然而令人驚訝的是,往往嚴重過度參數化的現代神經網絡,即便在沒有顯式正則化的情況下,依然表現出了優異的泛化能力。[9][10]於是我們可以利用無脊核回歸來研究過參數化神經網絡的泛化特性。有研究[11][12][13]推導出了描述高維核回歸期望泛化誤差的方程,而這些結果可以用來解釋足夠寬的神經網絡在經過最小二乘損失訓練後的泛化能力。

Remove ads

全局最小值

對於具有全局最小值損失泛函,如果NTK在訓練期間始終保持正定,那麼當時,神經網絡的損失就能保證收斂到該全局最小值。這種正定性已在一些情況下得到證實,從而首次證明了足夠寬的神經網絡在訓練過程確實能夠收斂到全局最小值。[1][14][15][16][17][18]

Remove ads

擴展與局限

NTK可用於研究各種神經網絡架構[2],如卷積神經網絡(CNN)[19]循環神經網絡(RNN) 和Transformer等。[20]在分析這些架構時,無限寬度極限的含義也相應有所變化,通常指在保持網絡層數固定的情況下增加參數的數量。以CNN為例,「寬度」對應的是卷積層的通道數。

在標準的核機制下,寬神經網絡中的各個參數在訓練過程中幾乎不會發生變化。然而,這意味著無限寬神經網絡無法進行特徵學習。而特徵學習被廣泛認為是實際應用中深度神經網絡的一個重要特性。值得注意的是,這一點並不是無限寬神經網絡的普遍特徵,而很大程度上是源於我們將寬度推向無窮時採取的一種特定的參數縮放方式。事實上,已有研究[21][22][23][24]探索不同縮放極限下的無限寬神經網絡,在這些情形下神經網絡與核回歸之間的對偶性不復存在,而特徵學習則得以出現。此外,還有研究[25]引入了「神經正切層級」(neural tangent hierarchy)的概念來描述有限寬度效應及其與特徵學習間的聯繫。

Neural Tangents是一個由Google開發的免費開源Python庫,可以用於計算和推斷與各常見神經網絡架構相對應的無限寬度NTK和神經網絡高斯過程[26]此外,還有一個名為scikit-ntk的庫提供與scikit-learn兼容的NTK工具。[27]

Remove ads

細節

當通過梯度下降法來優化一個神經網絡的參數並最小化其經驗損失時,NTK在整個訓練過程中主導著網絡輸出函數的動態變化。

情形一:純量輸出

一個輸出為純量的神經網絡可以被看作是由參數定義的一系列函數

其對應的NTK是一個核函數,定義為

在核方法的框架下,是與特徵映射相關聯的核函數。為了理解這一個核函數如何驅動神經網絡的訓練動態,我們可以考慮一個數據集,其中是輸入,是對應的純量標籤,則是損失函數。於是,定義在上的經驗損失可表示為

當通過連續時間梯度下降來訓練神經網格以擬合數據集(即最小化)時,其參數的演化遵循以下常微分方程

在訓練過程中,網絡輸出函數本身的演化則可以由一個以NTK表示的微分方程來描述:

這個方程揭示了訓練期間NTK是如何驅動在函數空間中的變化的。

Remove ads

情形二:向量輸出

當神經網絡的輸出是維度為的向量時,它可以看作是由參數定義的一系列函數

此種情況下,NTK是一個具有矩陣值的核函數,將輸入映射到的矩陣上,定義為

經驗風險最小化的過程與純量情形類似,主要的區別在於損失函數改為採用向量輸入,而其在連續時間梯度下降訓練過程中函數空間的演化動態也同樣由NTK所主導。對應的演化方程為:

這一方程是情形一中純量輸出方程的直接推廣。

Remove ads

解釋

在訓練的每一步,每個訓練數據點都會對任意一個輸入所對應的網絡輸出的演化產生影響。更具體地說,對於第個訓練樣本,其產生的梯度損失會對的更新作出貢獻,而NTK的值則決定了該貢獻的大小。在純量輸出的情形下,這一過程可以用以下離散時間的梯度下降更新公式來直觀地表示:

Remove ads

訓練中保持不變的確定性NTK

考慮一個包含全連接層的神經網絡,其各層寬度為。該網絡可以表示成多層函數的複合 ,其中每一層的都是先進行一次仿射變換,再逐點應用一個非線性激活函數。網絡的全部可訓練參數定義了這些仿射變換,並在訓練開始前以獨立同分布的方式進行隨機初始化。

隨著網絡寬度的增加,NTK的尺度會受到及參數初始化的影響。為得到一個性質良好的極限,有研究提出了所謂的NTK參數化: 。採用這種參數化並以標準常態分布來初始化所有參數時,可以確保當網絡寬度趨於無窮時,NTK會收斂到一個有限且非平凡的極限不僅是一個確定性(非隨機)的函數 ,而且在訓練過程中保持不變。

可以將表示為, 而則能通過以下遞歸方程組計算得到:

在上述方程中,定義了由舊核和函數生成新核的方式,其計算依賴於高斯期望

式中的被稱為神經網絡的激活核。[28][29][5]

Remove ads

訓練中參數的線性化

NTK描述了神經網絡在函數空間中的動態演化,而與相對應的另一個視角則是網絡在參數空間中的演化。在無限寬度極限下,這兩個視角之間的關聯變得十分有趣。與NTK在訓練中保持恆定這一現象同時發生的,是神經網絡在整個訓練過程中的變化都可以由其在初始化參數處的一階泰勒展開很好地近似:[6]

參考文獻

Loading related searches...

Wikiwand - on

Seamless Wikipedia browsing. On steroids.

Remove ads