0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

為什么SAM可以實(shí)現(xiàn)更好的泛化?如何在Pytorch中實(shí)現(xiàn)SAM?

深度學(xué)習(xí)自然語(yǔ)言處理 ? 來(lái)源:AI公園 ? 作者:Sean Benhur J ? 2021-04-03 14:48 ? 次閱讀

導(dǎo)讀 使用SAM(銳度感知最小化),優(yōu)化到損失的最平坦的最小值的地方,增強(qiáng)泛化能力。

動(dòng)機(jī)來(lái)自先前的工作,在此基礎(chǔ)上,我們提出了一種新的、有效的方法來(lái)同時(shí)減小損失值和損失的銳度。具體來(lái)說(shuō),在我們的處理過(guò)程中,進(jìn)行銳度感知最小化(SAM),在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)。這個(gè)公式產(chǎn)生了一個(gè)最小-最大優(yōu)化問(wèn)題,在這個(gè)問(wèn)題上梯度下降可以有效地執(zhí)行。我們提出的實(shí)證結(jié)果表明,SAM在各種基準(zhǔn)數(shù)據(jù)集上都改善了的模型泛化。

深度學(xué)習(xí)中,我們使用SGD/Adam等優(yōu)化算法在我們的模型中實(shí)現(xiàn)收斂,從而找到全局最小值,即訓(xùn)練數(shù)據(jù)集中損失較低的點(diǎn)。但等幾種研究表明,許多網(wǎng)絡(luò)可以很容易地記住訓(xùn)練數(shù)據(jù)并有能力隨時(shí)overfit,為了防止這個(gè)問(wèn)題,增強(qiáng)泛化能力,谷歌研究人員發(fā)表了一篇新論文叫做Sharpness Awareness Minimization,在CIFAR10上以及其他的數(shù)據(jù)集上達(dá)到了最先進(jìn)的結(jié)果。

在本文中,我們將看看為什么SAM可以實(shí)現(xiàn)更好的泛化,以及我們?nèi)绾卧赑ytorch中實(shí)現(xiàn)SAM。

SAM的原理是什么?

在梯度下降或任何其他優(yōu)化算法中,我們的目標(biāo)是找到一個(gè)具有低損失值的參數(shù)。但是,與其他常規(guī)的優(yōu)化方法相比,SAM實(shí)現(xiàn)了更好的泛化,它將重點(diǎn)放在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)(而不是只有參數(shù)本身具有低損失值)上。

由于計(jì)算鄰域參數(shù)而不是計(jì)算單個(gè)參數(shù),損失超平面比其他優(yōu)化方法更平坦,這反過(guò)來(lái)增強(qiáng)了模型的泛化。

(左))用SGD訓(xùn)練的ResNet收斂到的一個(gè)尖銳的最小值。(右)用SAM訓(xùn)練的相同的ResNet收斂到的一個(gè)平坦的最小值。

注意:SAM不是一個(gè)新的優(yōu)化器,它與其他常見(jiàn)的優(yōu)化器一起使用,比如SGD/Adam。

在Pytorch中實(shí)現(xiàn)SAM

在Pytorch中實(shí)現(xiàn)SAM非常簡(jiǎn)單和直接

import torch

class SAM(torch.optim.Optimizer):

def __init__(self, params, base_optimizer, rho=0.05, **kwargs):

assert rho 》= 0.0, f“Invalid rho, should be non-negative: {rho}”

defaults = dict(rho=rho, **kwargs)

super(SAM, self).__init__(params, defaults)

self.base_optimizer = base_optimizer(self.param_groups, **kwargs)

self.param_groups = self.base_optimizer.param_groups

@torch.no_grad()

def first_step(self, zero_grad=False):

grad_norm = self._grad_norm()

for group in self.param_groups:

scale = group[“rho”] / (grad_norm + 1e-12)

for p in group[“params”]:

if p.grad is None: continue

e_w = p.grad * scale.to(p)

p.add_(e_w) # climb to the local maximum “w + e(w)”

self.state[p][“e_w”] = e_w

if zero_grad: self.zero_grad()

@torch.no_grad()

def second_step(self, zero_grad=False):

for group in self.param_groups:

for p in group[“params”]:

if p.grad is None: continue

p.sub_(self.state[p][“e_w”]) # get back to “w” from “w + e(w)”

self.base_optimizer.step() # do the actual “sharpness-aware” update

if zero_grad: self.zero_grad()

def _grad_norm(self):

shared_device = self.param_groups[0][“params”][0].device # put everything on the same device, in case of model parallelism

norm = torch.norm(

torch.stack([

p.grad.norm(p=2).to(shared_device)

for group in self.param_groups for p in group[“params”]

if p.grad is not None

]),

p=2

return norm

代碼取自非官方的Pytorch實(shí)現(xiàn)。

代碼解釋:

首先,我們從Pytorch繼承優(yōu)化器類來(lái)創(chuàng)建一個(gè)優(yōu)化器,盡管SAM不是一個(gè)新的優(yōu)化器,而是在需要繼承該類的每一步更新梯度(在基礎(chǔ)優(yōu)化器的幫助下)。

該類接受模型參數(shù)、基本優(yōu)化器和rho, rho是計(jì)算最大損失的鄰域大小。

在進(jìn)行下一步之前,讓我們先看看文中提到的偽代碼,它將幫助我們?cè)跊](méi)有數(shù)學(xué)的情況下理解上述代碼。

bf4472f8-92a2-11eb-8b86-12bb97331649.jpg

正如我們?cè)谟?jì)算第一次反向傳遞后的偽代碼中看到的,我們計(jì)算epsilon并將其添加到參數(shù)中,這些步驟是在上述python代碼的方法first_step中實(shí)現(xiàn)的。

現(xiàn)在在計(jì)算了第一步之后,我們必須回到之前的權(quán)重來(lái)計(jì)算基礎(chǔ)優(yōu)化器的實(shí)際步驟,這些步驟在函數(shù)second_step中實(shí)現(xiàn)。

函數(shù)_grad_norm用于返回矩陣向量的norm,即偽代碼的第10行

在構(gòu)建這個(gè)類后,你可以簡(jiǎn)單地使用它為你的深度學(xué)習(xí)項(xiàng)目通過(guò)以下的訓(xùn)練函數(shù)片段。

from sam import SAM

。。.

model = YourModel()

base_optimizer = torch.optim.SGD # define an optimizer for the “sharpness-aware” update

optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

。。.

for input, output in data:

# first forward-backward pass

loss = loss_function(output, model(input)) # use this loss for any training statistics

loss.backward()

optimizer.first_step(zero_grad=True)

# second forward-backward pass

loss_function(output, model(input)).backward() # make sure to do a full forward pass

optimizer.second_step(zero_grad=True)

。。.

總結(jié)

雖然SAM的泛化效果較好,但是這種方法的主要缺點(diǎn)是,由于前后兩次計(jì)算銳度感知梯度,需要花費(fèi)兩倍的訓(xùn)練時(shí)間。除此之外,SAM還在最近發(fā)布的NFNETS上證明了它的效果,這是ImageNet目前的最高水平,在未來(lái),我們可以期待越來(lái)越多的論文利用這一技術(shù)來(lái)實(shí)現(xiàn)更好的泛化。

英文原文:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81
編輯:lyn

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • SAM
    SAM
    +關(guān)注

    關(guān)注

    0

    文章

    111

    瀏覽量

    33480
  • 深度學(xué)習(xí)
    +關(guān)注

    關(guān)注

    73

    文章

    5466

    瀏覽量

    120894
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    802

    瀏覽量

    13116

原文標(biāo)題:【過(guò)擬合】再也不用擔(dān)心過(guò)擬合的問(wèn)題了

文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    何在 PyTorch 訓(xùn)練模型

    Dataset Dataset 類需要我們實(shí)現(xiàn) __init__ 、 __len__ 和 __getitem__ 三個(gè)方法。 __init__ 方法用于初始數(shù)據(jù)集, __len__ 返回?cái)?shù)據(jù)集中的樣
    的頭像 發(fā)表于 11-05 17:36 ?213次閱讀

    何在反激式拓?fù)?b class='flag-5'>中實(shí)現(xiàn)軟啟動(dòng)

    電子發(fā)燒友網(wǎng)站提供《如何在反激式拓?fù)?b class='flag-5'>中實(shí)現(xiàn)軟啟動(dòng).pdf》資料免費(fèi)下載
    發(fā)表于 09-04 11:09 ?0次下載
    如<b class='flag-5'>何在</b>反激式拓?fù)?b class='flag-5'>中</b><b class='flag-5'>實(shí)現(xiàn)</b>軟啟動(dòng)

    何在FPGA實(shí)現(xiàn)隨機(jī)數(shù)發(fā)生器

    分享如何在Xilinx Breadboardable Spartan-7 FPGA, CMOD S7實(shí)現(xiàn)4位偽隨機(jī)數(shù)發(fā)生器(PRNGs)。
    的頭像 發(fā)表于 08-06 11:20 ?554次閱讀
    如<b class='flag-5'>何在</b>FPGA<b class='flag-5'>中</b><b class='flag-5'>實(shí)現(xiàn)</b>隨機(jī)數(shù)發(fā)生器

    何在Tensorflow實(shí)現(xiàn)反卷積

    ,扮演著重要角色。以下將詳細(xì)闡述如何在TensorFlow實(shí)現(xiàn)反卷積,包括其理論基礎(chǔ)、TensorFlow實(shí)現(xiàn)方式、以及實(shí)際應(yīng)用
    的頭像 發(fā)表于 07-14 10:46 ?521次閱讀

    PyTorch如何實(shí)現(xiàn)多層全連接神經(jīng)網(wǎng)絡(luò)

    PyTorch實(shí)現(xiàn)多層全連接神經(jīng)網(wǎng)絡(luò)(也稱為密集連接神經(jīng)網(wǎng)絡(luò)或DNN)是一個(gè)相對(duì)直接的過(guò)程,涉及定義網(wǎng)絡(luò)結(jié)構(gòu)、初始參數(shù)、前向傳播、損失計(jì)算和反向傳播等步驟。
    的頭像 發(fā)表于 07-11 16:07 ?987次閱讀

    何在PyTorch實(shí)現(xiàn)LeNet-5網(wǎng)絡(luò)

    等人提出,主要用于手寫數(shù)字識(shí)別任務(wù)(如MNIST數(shù)據(jù)集)。下面,我將詳細(xì)闡述如何在PyTorch從頭開(kāi)始實(shí)現(xiàn)LeNet-5網(wǎng)絡(luò),包括網(wǎng)絡(luò)架構(gòu)設(shè)計(jì)、參數(shù)初始
    的頭像 發(fā)表于 07-11 10:58 ?662次閱讀

    【愛(ài)芯派 Pro 開(kāi)發(fā)板試用體驗(yàn)】+ 交互式摳圖軟件的實(shí)現(xiàn)

    軟件。 SAM的編譯 我們先介紹一下如何在開(kāi)發(fā)板上編譯原始的SAM程序原始代碼由愛(ài)芯官方開(kāi)源于 GITHUB:https://github.com/AXERA-TECH
    發(fā)表于 01-02 22:04

    一種新的分割模型Stable-SAM

    SAM、HQ-SAM、Stable-SAM在提供次優(yōu)提示時(shí)的性能比較,Stable-SAM明顯優(yōu)于其他算法。這里也推薦工坊推出的新課程《如何將深度學(xué)習(xí)模型部署到實(shí)際工程
    的頭像 發(fā)表于 12-29 14:35 ?608次閱讀
    一種新的分割模型Stable-<b class='flag-5'>SAM</b>

    【愛(ài)芯派 Pro 開(kāi)發(fā)板試用體驗(yàn)】+ 圖像分割和填充的Demo測(cè)試

    的實(shí)時(shí)分割和進(jìn)一步可選修復(fù)。原始代碼由愛(ài)芯官方開(kāi)源于 GITHUB:https://github.com/AXERA-TECH/SAM-ONNX-AX650-CPP。倉(cāng)庫(kù)內(nèi)的文本介紹了如何在AX650N
    發(fā)表于 12-26 11:22

    利用 MPLAB? Harmony v3 TCP/IP協(xié)議棧在SAM E54 MCU 上實(shí)現(xiàn)文件傳輸協(xié)議

    電子發(fā)燒友網(wǎng)站提供《利用 MPLAB? Harmony v3 TCP/IP協(xié)議棧在SAM E54 MCU 上實(shí)現(xiàn)文件傳輸協(xié)議.pdf》資料免費(fèi)下載
    發(fā)表于 12-18 11:03 ?0次下載
    利用 MPLAB? Harmony v3 TCP/IP協(xié)議棧在<b class='flag-5'>SAM</b> E54 MCU 上<b class='flag-5'>實(shí)現(xiàn)</b>文件傳輸協(xié)議

    LIO-SAM框架是什么

    LIO-SAM的全稱是:Tightly-coupled Lidar Inertial Odometry via Smoothing and Mapping,從全稱上可以看出,該算法是一個(gè)緊耦合的雷達(dá)
    的頭像 發(fā)表于 11-24 17:08 ?1124次閱讀
    LIO-<b class='flag-5'>SAM</b>框架是什么

    何在 3DICC 基于虛擬原型實(shí)現(xiàn)多芯片架構(gòu)探索

    何在 3DICC 基于虛擬原型實(shí)現(xiàn)多芯片架構(gòu)探索
    的頭像 發(fā)表于 11-23 09:04 ?441次閱讀
    如<b class='flag-5'>何在</b> 3DICC <b class='flag-5'>中</b>基于虛擬原型<b class='flag-5'>實(shí)現(xiàn)</b>多芯片架構(gòu)探索

    3d激光SLAMLIO-SAM框架介紹

    里程計(jì)的框架。 實(shí)現(xiàn)了高精度、實(shí)時(shí)的移動(dòng)機(jī)器人的軌跡估計(jì)和建圖。 本篇博客重點(diǎn)解讀LIO-SAM框架下IMU預(yù)積分功能數(shù)據(jù)初始代碼部分 LIO-SAM 的代碼主要在其主目錄內(nèi)的src
    的頭像 發(fā)表于 11-22 15:04 ?996次閱讀
    3d激光SLAMLIO-<b class='flag-5'>SAM</b>框架介紹

    何在苛刻的熱限條件下實(shí)現(xiàn)增強(qiáng)的可視計(jì)算

    電子發(fā)燒友網(wǎng)站提供《如何在苛刻的熱限條件下實(shí)現(xiàn)增強(qiáng)的可視計(jì)算.pdf》資料免費(fèi)下載
    發(fā)表于 11-15 14:19 ?0次下載
    如<b class='flag-5'>何在</b>苛刻的熱限條件下<b class='flag-5'>實(shí)現(xiàn)</b>增強(qiáng)的可視<b class='flag-5'>化</b>計(jì)算

    NTAG5連線如何使用MIFARE SAM AV3?

    我們想在NTAG5鏈接上執(zhí)行MIFARE SAM AV3, 有沒(méi)有參考文件或演示代碼? 我看過(guò)\"AN12698-MIFARE SAM AV3為NTAG 5, ICODE DNA
    發(fā)表于 11-13 06:13