導(dǎo)讀 使用SAM(銳度感知最小化),優(yōu)化到損失的最平坦的最小值的地方,增強(qiáng)泛化能力。
動(dòng)機(jī)來(lái)自先前的工作,在此基礎(chǔ)上,我們提出了一種新的、有效的方法來(lái)同時(shí)減小損失值和損失的銳度。具體來(lái)說(shuō),在我們的處理過(guò)程中,進(jìn)行銳度感知最小化(SAM),在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)。這個(gè)公式產(chǎn)生了一個(gè)最小-最大優(yōu)化問(wèn)題,在這個(gè)問(wèn)題上梯度下降可以有效地執(zhí)行。我們提出的實(shí)證結(jié)果表明,SAM在各種基準(zhǔn)數(shù)據(jù)集上都改善了的模型泛化。
在深度學(xué)習(xí)中,我們使用SGD/Adam等優(yōu)化算法在我們的模型中實(shí)現(xiàn)收斂,從而找到全局最小值,即訓(xùn)練數(shù)據(jù)集中損失較低的點(diǎn)。但等幾種研究表明,許多網(wǎng)絡(luò)可以很容易地記住訓(xùn)練數(shù)據(jù)并有能力隨時(shí)overfit,為了防止這個(gè)問(wèn)題,增強(qiáng)泛化能力,谷歌研究人員發(fā)表了一篇新論文叫做Sharpness Awareness Minimization,在CIFAR10上以及其他的數(shù)據(jù)集上達(dá)到了最先進(jìn)的結(jié)果。
在本文中,我們將看看為什么SAM可以實(shí)現(xiàn)更好的泛化,以及我們?nèi)绾卧赑ytorch中實(shí)現(xiàn)SAM。
SAM的原理是什么?
在梯度下降或任何其他優(yōu)化算法中,我們的目標(biāo)是找到一個(gè)具有低損失值的參數(shù)。但是,與其他常規(guī)的優(yōu)化方法相比,SAM實(shí)現(xiàn)了更好的泛化,它將重點(diǎn)放在領(lǐng)域內(nèi)尋找具有均勻的低損失值的參數(shù)(而不是只有參數(shù)本身具有低損失值)上。
由于計(jì)算鄰域參數(shù)而不是計(jì)算單個(gè)參數(shù),損失超平面比其他優(yōu)化方法更平坦,這反過(guò)來(lái)增強(qiáng)了模型的泛化。
(左))用SGD訓(xùn)練的ResNet收斂到的一個(gè)尖銳的最小值。(右)用SAM訓(xùn)練的相同的ResNet收斂到的一個(gè)平坦的最小值。
注意:SAM不是一個(gè)新的優(yōu)化器,它與其他常見(jiàn)的優(yōu)化器一起使用,比如SGD/Adam。
在Pytorch中實(shí)現(xiàn)SAM
在Pytorch中實(shí)現(xiàn)SAM非常簡(jiǎn)單和直接
import torch
class SAM(torch.optim.Optimizer):
def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
assert rho 》= 0.0, f“Invalid rho, should be non-negative: {rho}”
defaults = dict(rho=rho, **kwargs)
super(SAM, self).__init__(params, defaults)
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
@torch.no_grad()
def first_step(self, zero_grad=False):
grad_norm = self._grad_norm()
for group in self.param_groups:
scale = group[“rho”] / (grad_norm + 1e-12)
for p in group[“params”]:
if p.grad is None: continue
e_w = p.grad * scale.to(p)
p.add_(e_w) # climb to the local maximum “w + e(w)”
self.state[p][“e_w”] = e_w
if zero_grad: self.zero_grad()
@torch.no_grad()
def second_step(self, zero_grad=False):
for group in self.param_groups:
for p in group[“params”]:
if p.grad is None: continue
p.sub_(self.state[p][“e_w”]) # get back to “w” from “w + e(w)”
self.base_optimizer.step() # do the actual “sharpness-aware” update
if zero_grad: self.zero_grad()
def _grad_norm(self):
shared_device = self.param_groups[0][“params”][0].device # put everything on the same device, in case of model parallelism
norm = torch.norm(
torch.stack([
p.grad.norm(p=2).to(shared_device)
for group in self.param_groups for p in group[“params”]
if p.grad is not None
]),
p=2
)
return norm
代碼取自非官方的Pytorch實(shí)現(xiàn)。
代碼解釋:
首先,我們從Pytorch繼承優(yōu)化器類來(lái)創(chuàng)建一個(gè)優(yōu)化器,盡管SAM不是一個(gè)新的優(yōu)化器,而是在需要繼承該類的每一步更新梯度(在基礎(chǔ)優(yōu)化器的幫助下)。
該類接受模型參數(shù)、基本優(yōu)化器和rho, rho是計(jì)算最大損失的鄰域大小。
在進(jìn)行下一步之前,讓我們先看看文中提到的偽代碼,它將幫助我們?cè)跊](méi)有數(shù)學(xué)的情況下理解上述代碼。
正如我們?cè)谟?jì)算第一次反向傳遞后的偽代碼中看到的,我們計(jì)算epsilon并將其添加到參數(shù)中,這些步驟是在上述python代碼的方法first_step中實(shí)現(xiàn)的。
現(xiàn)在在計(jì)算了第一步之后,我們必須回到之前的權(quán)重來(lái)計(jì)算基礎(chǔ)優(yōu)化器的實(shí)際步驟,這些步驟在函數(shù)second_step中實(shí)現(xiàn)。
函數(shù)_grad_norm用于返回矩陣向量的norm,即偽代碼的第10行
在構(gòu)建這個(gè)類后,你可以簡(jiǎn)單地使用它為你的深度學(xué)習(xí)項(xiàng)目通過(guò)以下的訓(xùn)練函數(shù)片段。
from sam import SAM
。。.
model = YourModel()
base_optimizer = torch.optim.SGD # define an optimizer for the “sharpness-aware” update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
。。.
for input, output in data:
# first forward-backward pass
loss = loss_function(output, model(input)) # use this loss for any training statistics
loss.backward()
optimizer.first_step(zero_grad=True)
# second forward-backward pass
loss_function(output, model(input)).backward() # make sure to do a full forward pass
optimizer.second_step(zero_grad=True)
。。.
總結(jié)
雖然SAM的泛化效果較好,但是這種方法的主要缺點(diǎn)是,由于前后兩次計(jì)算銳度感知梯度,需要花費(fèi)兩倍的訓(xùn)練時(shí)間。除此之外,SAM還在最近發(fā)布的NFNETS上證明了它的效果,這是ImageNet目前的最高水平,在未來(lái),我們可以期待越來(lái)越多的論文利用這一技術(shù)來(lái)實(shí)現(xiàn)更好的泛化。
英文原文:https://pub.towardsai.net/we-dont-need-to-worry-about-overfitting-anymore-9fb31a154c81
編輯:lyn
-
SAM
+關(guān)注
關(guān)注
0文章
111瀏覽量
33480 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5466瀏覽量
120894 -
pytorch
+關(guān)注
關(guān)注
2文章
802瀏覽量
13116
原文標(biāo)題:【過(guò)擬合】再也不用擔(dān)心過(guò)擬合的問(wèn)題了
文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語(yǔ)言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論