論文鏈接:https://arxiv.org/abs/2305.17476
代碼鏈接:
https://github.com/ML-GSAI/Understanding-GDA
概述
生成式數(shù)據(jù)擴增通過條件生成模型生成新樣本來擴展數(shù)據(jù)集,從而提高各種學(xué)習(xí)任務(wù)的分類性能。然而,很少有人從理論上研究生成數(shù)據(jù)增強的效果。為了填補這一空白,我們在這種非獨立同分布環(huán)境下構(gòu)建了基于穩(wěn)定性的通用泛化誤差界?;谕ㄓ玫姆夯?,我們進一步了探究了高斯混合模型和生成對抗網(wǎng)絡(luò)的學(xué)習(xí)情況。
在這兩種情況下,我們證明了,雖然生成式數(shù)據(jù)增強并不能享受更快的學(xué)習(xí)率,但當(dāng)訓(xùn)練集較小時,它可以在一個常數(shù)的水平上提高學(xué)習(xí)保證,這在發(fā)生過擬合時是非常重要的。最后,高斯混合模型的仿真結(jié)果和生成式對抗網(wǎng)絡(luò)的實驗結(jié)果都支持我們的理論結(jié)論。
主要的理論結(jié)果
2.1 符號與定義
讓 作為數(shù)據(jù)輸入空間, 作為標(biāo)簽空間。定義 為 上的真實分布。給定集合 ,我們定義 為去掉第 個數(shù)據(jù)后剩下的集合, 為把第 個數(shù)據(jù)換成 后的集合。我們用 表示 total variation distance。
我們讓 為所有從 到 的所有可測函數(shù), 為學(xué)習(xí)算法,為從數(shù)據(jù)集 中學(xué)到的映射。對于一個學(xué)到的映射 和損失函數(shù),真實誤差 被定義為。相應(yīng)的經(jīng)驗的誤差 被定義為。
我們文章理論推導(dǎo)采用的是穩(wěn)定性框架,我們稱算法 相對于損失函數(shù) 是一致 穩(wěn)定的,如果
2.2 生成式數(shù)據(jù)增強
給定帶有 個 i.i.d. 樣本的 數(shù)據(jù)集,我們能訓(xùn)練一個條件生成模型 ,并將學(xué)到的分布定義為 ?;谟?xùn)練得到的條件生成模型,我們能生成一個新的具有 個 i.i.d. 樣本的數(shù)據(jù)集 。我們記增廣后的數(shù)據(jù)集 大小為 。我們可以在增廣后的數(shù)據(jù)集上學(xué)到映射 。為了理解生成式數(shù)據(jù)增強,我們關(guān)心泛化誤差 。據(jù)我們所知,這是第一個理解生成式數(shù)據(jù)增強泛化誤差的工作。2.3 一般情況
我們可以對于任意的生成器和一致 穩(wěn)定的分類器,推得如下的泛化誤差:▲ general一般來說,我們比較關(guān)心泛化誤差界關(guān)于樣本數(shù) 的收斂率。將 看成超參數(shù),并將后面兩項記為 generalization error w.r.t. mixed distribution,我們可以定義如下的“最有效的增強數(shù)量”:在這個設(shè)置下,并和沒有數(shù)據(jù)增強的情況進行對比(),我們可以得到如下的充分條件,它刻畫了生成式數(shù)據(jù)增強何時(不)能夠促進下游分類任務(wù),這和生成模型學(xué)習(xí)分的能力息息相關(guān):
▲ corollary
2.4 高斯混合模型為了驗證我們理論的正確性,我們先考慮了一個簡單的高斯混合模型的 setting。 混合高斯分布。我們考慮二分類任務(wù) 。我們假設(shè)真實分布滿足 and 。我們假設(shè) 的分布是已知的。 線性分類器。我們考慮一個被 參數(shù)化的分類器,預(yù)測函數(shù)為 。給定訓(xùn)練集, 通過最小化負(fù)對數(shù)似然損失函數(shù)得到,即最小化學(xué)習(xí)算法將會推得 ,which satisfies 條件生成模型。我們考慮參數(shù)為 的條件生成模型,其中 以及 。給定訓(xùn)練集,讓 為第 類的樣本量,條件生成模型學(xué)到
它們是 和 的無偏估計。我們可以從這個條件模型中進行采樣,即 ,,其中 。 我們在高斯混合模型的場景下具體計算 Theorem 3.1 中的各個項,可以推得
▲ GMM
- 當(dāng)數(shù)據(jù)量 足夠時,即使我們采用“最有效的增強數(shù)量”,生成式數(shù)據(jù)增強也難以提高下游任務(wù)的分類性能。
- 當(dāng)數(shù)據(jù)量 較小的,此時主導(dǎo)泛化誤差的是維度等其他項,此時進行生成式數(shù)據(jù)增強可以常數(shù)級降低泛化誤差,這意味著在過擬合的場景下,生成式數(shù)據(jù)增強是很有必要的。
2.5 生成對抗網(wǎng)絡(luò)
我們也考慮了深度學(xué)習(xí)的情況。我們假設(shè)生成模型為 MLP 生成對抗網(wǎng)絡(luò),分類器為 層 MLP 或者 CNN。損失函數(shù)為二元交叉熵,優(yōu)化算法為 SGD。我們假設(shè)損失函數(shù)平滑,并且第 層的神經(jīng)網(wǎng)絡(luò)參數(shù)可以被 控制。我們可以推得如下的結(jié)論:▲ GAN
- 當(dāng)數(shù)據(jù)量 足夠時,生成式數(shù)據(jù)增強也難以提高下游任務(wù)的分類性能,甚至?xí)夯?/span>
- 當(dāng)數(shù)據(jù)量 較小的,此時主導(dǎo)泛化誤差的是維度等其他項,此時進行生成式數(shù)據(jù)增強可以常數(shù)級降低泛化誤差,同樣地,這意味著在過擬合的場景下,生成式數(shù)據(jù)增強是很有必要的。
實驗
3.1 高斯混合模型模擬實驗
我們在混合高斯分布上驗證我們的理論,我們調(diào)整數(shù)據(jù)量 ,數(shù)據(jù)維度 以及 。實驗結(jié)果如下圖所示:
▲ simulation
- 觀察圖(a),我們可以發(fā)現(xiàn)當(dāng) 相對于 足夠大的時候,生成式數(shù)據(jù)增強的引入并不能明顯改變泛化誤差。
- 觀察圖(d),我們可以發(fā)現(xiàn)當(dāng) 固定時,真實的泛化誤差確實是 階的,且隨著增強數(shù)量 的增大,泛化誤差呈現(xiàn)常數(shù)級的降低。
- 另外 4 張圖,我們選取了兩種情況,驗證了我們的 bound 能在趨勢上一定程度上預(yù)測泛化誤差。
▲ deep
- 在沒有額外數(shù)據(jù)增強的時候, 較小,分類器陷入了嚴(yán)重的過擬合。此時,即使選取的 cDCGAN 很古早(bad GAN),生成式數(shù)據(jù)增強都能帶來明顯的提升。
- 在有額外數(shù)據(jù)增強的時候, 充足。此時,即使選取的 StyleGAN 很先進(SOTA GAN),生成式數(shù)據(jù)增強都難以帶來明顯的提升,在 50k 和 100k 增強的情況下甚至都造成了一致的損害。
-
我們也測試了一個 SOTA 的擴散模型 EDM,發(fā)現(xiàn)即使在有額外數(shù)據(jù)增強的時候,生成式數(shù)據(jù)增強也能提升分類效果。這意味著擴散模型學(xué)習(xí)分布的能力可能會優(yōu)于 GAN。
原文標(biāo)題:NeurIPS 2023 | 如何從理論上研究生成式數(shù)據(jù)增強的效果?
文章出處:【微信公眾號:智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
-
物聯(lián)網(wǎng)
+關(guān)注
關(guān)注
2900文章
44062瀏覽量
370247
原文標(biāo)題:NeurIPS 2023 | 如何從理論上研究生成式數(shù)據(jù)增強的效果?
文章出處:【微信號:tyutcsplab,微信公眾號:智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論