0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

數(shù)據(jù)類別不均衡問題的分類及解決方式

深度學(xué)習(xí)自然語言處理 ? 來源:PaperWeekly ? 作者:PaperWeekly ? 2022-07-08 14:51 ? 次閱讀

數(shù)據(jù)類別不均衡問題應(yīng)該是一個極常見又頭疼的的問題了。最近在工作中也是碰到這個問題,花了些時間梳理并實(shí)踐了類別不均衡問題的解決方式,主要實(shí)踐了“魔改”loss(focal loss, GHM loss, dice loss 等),整理如下。

所有的 Loss 實(shí)踐代碼在這里:

https://github.com/shuxinyin/NLP-Loss-Pytorch

數(shù)據(jù)不均衡問題也可以說是一個長尾問題,但長尾那部分?jǐn)?shù)據(jù)往往是重要且不能被忽略的,它不僅僅是分類標(biāo)簽下樣本數(shù)量的不平衡,實(shí)質(zhì)上也是難易樣本的不平衡。

解決不均衡問題一般從兩方面入手:

數(shù)據(jù)層面:重采樣,使得參與迭代計算的數(shù)據(jù)是均衡的;

模型層面:重加權(quán),修改模型的 loss,在 loss 計算上,加大對少樣本的 loss 獎勵。

1. 數(shù)據(jù)層面的重采樣

關(guān)于數(shù)據(jù)層面的重采樣,方式都是通過采樣,重新構(gòu)造數(shù)據(jù)分布,使得數(shù)據(jù)平衡。一般常用的有三種:

欠采樣;

過采樣;

SMOTE。

1. 欠采樣:指某類別下數(shù)據(jù)較多,則只采取部分?jǐn)?shù)據(jù),直接拋棄一些數(shù)據(jù),這種方式太簡單粗暴,擬合出來的模型的偏差大,泛化性能較差;

2. 過采樣:這種方式與欠采樣相反,某類別下數(shù)據(jù)較少,進(jìn)行重復(fù)采樣,達(dá)到數(shù)據(jù)平衡。因為這些少的數(shù)據(jù)反復(fù)迭代計算,會使得模型產(chǎn)生過擬合的現(xiàn)象。

3. SMOTE:一種近鄰插值,可以降低過擬合風(fēng)險,但它是適用于回歸預(yù)測場景下,而 NLP 任務(wù)一般是離散的情況。

這幾種方法單獨(dú)使用會或多或少造成數(shù)據(jù)的浪費(fèi)或重,一般會與 ensemble 方式結(jié)合使用,sample 多份數(shù)據(jù),訓(xùn)練出多個模型,最后綜合。

但以上幾種方式在工程實(shí)踐中往往是少用的,一是因為數(shù)真實(shí)據(jù)珍貴,二也是 ensemble 的方式部署中資源消耗大,沒法接受。因此,就集中看下重加權(quán) loss 改進(jìn)的部分。

2. 模型層面的重加權(quán)

重加權(quán)主要指的是在 loss 計算階段,通過設(shè)計 loss,調(diào)整類別的權(quán)值對 loss 的貢獻(xiàn)。比較經(jīng)典的 loss 改進(jìn)應(yīng)該是 Focal Loss, GHM Loss, Dice Loss。

2.1 Focal Loss

Focal Loss 是一種解決不平衡問題的經(jīng)典 loss,基本思想就是把注意力集中于那些預(yù)測不準(zhǔn)的樣本上。

何為預(yù)測不準(zhǔn)的樣本?比如正樣本的預(yù)測值小于 0.5 的,或者負(fù)樣本的預(yù)測值大于 0.5 的樣本。再簡單點(diǎn),就是當(dāng)正樣本預(yù)測值》0.5 時,在計算該樣本的 loss 時,給它一個小的權(quán)值,反之,正樣本預(yù)測值《0.5 時,給它一個大的權(quán)值。同理,對負(fù)樣本時也是如此。

以二分類為例,一般采用交叉熵作為模型損失。

其中 是真實(shí)標(biāo)簽, 是預(yù)測值,在此基礎(chǔ)又出來了一個權(quán)重交叉熵,即用一個超參去緩解上述這種影響,也就是下式。

接下來,看下 Focal Loss 是怎么做到集中關(guān)注預(yù)測不準(zhǔn)的樣本?

在交叉熵 loss 基礎(chǔ)上,當(dāng)正樣本預(yù)測值 大于 0.5 時,需要給它的 loss 一個小的權(quán)重值 ,使其對總 loss 影響小,反之正樣本預(yù)測值 小于 0.5,給它的 loss 一個大的權(quán)重值。為滿足以上要求,則 增大時, 應(yīng)減小,故剛好 可滿足上述要求。

因此加上注意參數(shù) ,得到 Focal Loss 的二分類情況:

加上調(diào)節(jié)系數(shù) ,F(xiàn)ocal Loss 推廣到多分類的情況:

其中 為第 t 類預(yù)測值,,試驗中效果最佳時,。

代碼的實(shí)現(xiàn)也是比較簡潔的。

def __init__(self, num_class, alpha=None, gamma=2, reduction=‘mean’): super(MultiFocalLoss, self).__init__() self.gamma = gamma 。..。..

def forward(self, logit, target): alpha = self.alpha.to(logit.device) prob = F.softmax(logit, dim=1)

ori_shp = target.shape target = target.view(-1, 1)

prob = prob.gather(1, target).view(-1) + self.smooth # avoid nan logpt = torch.log(prob)

alpha_weight = alpha[target.squeeze().long()] loss = -alpha_weight * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt

if self.reduction == ‘mean’: loss = loss.mean()

return loss

2.2 GHM Loss

上面的 Focal Loss 注重了對 hard example 的學(xué)習(xí),但不是所有的 hard example 都值得關(guān)注,有一些 hard example 很可能是離群點(diǎn),這種離群點(diǎn)當(dāng)然是不應(yīng)該讓模型關(guān)注的。

GHM (gradient harmonizing mechanism) 是一種梯度調(diào)和機(jī)制,GHM Loss 的改進(jìn)思想有兩點(diǎn):1)就是在使模型繼續(xù)保持對 hard example 關(guān)注的基礎(chǔ)上,使模型不去關(guān)注這些離群樣本;2)另外 Focal Loss 中, 的值分別由實(shí)驗經(jīng)驗得出,而一般情況下超參 是互相影響的,應(yīng)當(dāng)共同進(jìn)行實(shí)驗得到。

Focal Loss 中通過調(diào)節(jié)置信度 ,當(dāng)正樣本中模型的預(yù)測值 較小時,則乘上(1-p),給一個大的 loss 值使得模型關(guān)注這種樣本。于是 GHM Loss 在此基礎(chǔ)上,規(guī)定了一個置信度范圍 ,具體一點(diǎn),就是當(dāng)正樣本中模型的預(yù)測值為 較小時,要看這個 多小,若是 ,這種樣本可能就是離群點(diǎn),就不注意它了。

于是 GHM Loss 首先規(guī)定了一個梯度模長 :

其中, 是模型預(yù)測概率值, 是 ground-truth 的標(biāo)簽值,這里以二分類為例,取值為 0 或 1??砂l(fā)現(xiàn), 表示檢測的難易程度, 越大則檢測難度越大。

GHM Loss 的思想是,不要關(guān)注那些容易學(xué)的樣本,也不要關(guān)注那些離群點(diǎn)特別難分的樣本。所以問題就轉(zhuǎn)為我們需要尋找一個變量去衡量這個樣本是不是這兩種,這個變量需滿足當(dāng) 值大時,它要小,從而進(jìn)行抑制,當(dāng) 值小時,它也要小,進(jìn)行抑制。于是文中就引入了梯度密度:

表明了樣本 1~N 中,梯度模長分布在 范圍內(nèi)的樣本個數(shù), 代表了 區(qū)間的長度,因此梯度密度 GD(g) 的物理含義是:單位梯度模長 部分的樣本個數(shù)。

在此基礎(chǔ)上,還需要一個前提,那就是處于 值小與大的樣本(也就是易分樣本與難分樣本)的數(shù)量遠(yuǎn)多于中間值樣本,此時 GD 才可以滿足上述變量的要求。

此時,對于每個樣本,把交叉熵 CE×該樣本梯度密度的倒數(shù),就得到 GHM Loss。

這里附上邏輯的代碼,完整的可以上文章首尾倉庫查看。

class GHM_Loss(nn.Module): def __init__(self, bins, alpha): super(GHM_Loss, self).__init__() self._bins = bins self._alpha = alpha self._last_bin_count = None

def _g2bin(self, g): # split to n bins return torch.floor(g * (self._bins - 0.0001)).long()

def forward(self, x, target): # compute value g g = torch.abs(self._custom_loss_grad(x, target)).detach()

bin_idx = self._g2bin(g)

bin_count = torch.zeros((self._bins)) for i in range(self._bins): # 計算落入bins的梯度模長數(shù)量 bin_count[i] = (bin_idx == i).sum().item()

N = (x.size(0) * x.size(1))

if self._last_bin_count is None: self._last_bin_count = bin_count else: bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count self._last_bin_count = bin_count

nonempty_bins = (bin_count 》 0).sum().item()

gd = bin_count * nonempty_bins gd = torch.clamp(gd, min=0.0001) beta = N / gd # 計算好樣本的gd值

# 借由binary_cross_entropy_with_logits,gd值當(dāng)作參數(shù)傳入 return F.binary_cross_entropy_with_logits(x, target, weight=beta[bin_idx])

2.3 Dice Loss & DSC Loss

Dice Loss 是來自文章 V-Net 提出的,DSC Loss 是香儂科技的 Dice Loss for Data-imbalanced NLP Tasks。

按照上面的邏輯,看一下 Dice Loss 是怎么演變過來的。Dice Loss 主要來自于 dice coefficient,dice coefficient 是一種用于評估兩個樣本的相似性的度量函數(shù)。

定義是這樣的:取值范圍在 0 到 1 之間,值越大表示越相似。若令 X 是所有模型預(yù)測為正的樣本的集合,Y 為所有實(shí)際上為正類的樣本集合,dice coefficient 可重寫為:

同時,結(jié)合 F1 的指標(biāo)計算公式推一下,可得:

可以動手推一下,就能得到 dice coefficient 是等同 F1 score 的,因此本質(zhì)上 dice loss 是直接優(yōu)化 F1 指標(biāo)的。

上述表達(dá)式是離散的,需要把上述 DSC 表達(dá)式轉(zhuǎn)化為連續(xù)的版本,需要進(jìn)行軟化處理。對單個樣本 x,可以直接定義它的 DSC:

但是當(dāng)樣本為負(fù)樣本時,y1=0,loss 就為 0 了,需要加一個平滑項。

上面有說到 dice coefficient 是一種兩個樣本的相似性的度量函數(shù),上式中,假設(shè)正樣本 p 越大,dice 值越大,說明模型預(yù)測的越準(zhǔn),則應(yīng)該 loss 值越小,因此 dice loss 的就變成了下式這也就是最終 dice loss 的樣子。

為了能得到 focal loss 同樣的功能,讓 dice loss 集中關(guān)注預(yù)測不準(zhǔn)的樣本,可以與 focal loss 一樣加上一個調(diào)節(jié)系數(shù) ,就得到了香儂提出的適用于 NLP 任務(wù)的自調(diào)節(jié) DSC-Loss。

弄明白了原理,看下代碼的實(shí)現(xiàn)。

class DSCLoss(torch.nn.Module):

def __init__(self, alpha: float = 1.0, smooth: float = 1.0, reduction: str = “mean”): super().__init__() self.alpha = alpha self.smooth = smooth self.reduction = reduction

def forward(self, logits, targets): probs = torch.softmax(logits, dim=1) probs = torch.gather(probs, dim=1, index=targets.unsqueeze(1))

probs_with_factor = ((1 - probs) ** self.alpha) * probs loss = 1 - (2 * probs_with_factor + self.smooth) / (probs_with_factor + 1 + self.smooth)

if self.reduction == “mean”: return loss.mean()

總結(jié)

本文主要討論了類別不均衡問題的解決辦法,可分為數(shù)據(jù)層面的重采樣及模型 loss 方面的改進(jìn),如 focal loss, dice loss 等。最后說一下實(shí)踐下來的經(jīng)驗,由于不同數(shù)據(jù)集的數(shù)據(jù)分布特點(diǎn)各有不同,dice loss 以及 GHM loss 會出現(xiàn)些抖動、不穩(wěn)定的情況。當(dāng)不想挨個實(shí)踐的時候,首推 focal loss,dice loss。

以上所有 Loss 的代碼僅為邏輯參考,完整的代碼及相關(guān)參考論文都在:

https://github.com/shuxinyin/NLP-Loss-Pytorch

審核編輯:郭婷

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 數(shù)據(jù)
    +關(guān)注

    關(guān)注

    8

    文章

    6808

    瀏覽量

    88743
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4722

    瀏覽量

    68234

原文標(biāo)題:類別不均衡問題之loss大集合:focal loss, GHM loss, dice loss 等等

文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    負(fù)載均衡是什么意思?盤點(diǎn)常見的三種方式

    負(fù)載均衡是什么意思?負(fù)載均衡(LoadBalancing)是一種計算機(jī)技術(shù),主要用于在多個計算資源(如服務(wù)器、虛擬機(jī)、容器等)中分配和管理負(fù)載,以達(dá)到優(yōu)化資源使用、最大化吞吐率、最小化響應(yīng)時間,并
    的頭像 發(fā)表于 09-29 14:30 ?198次閱讀

    信道均衡的原理和分類介紹

    一、信道均衡的基本原理 信道均衡的基本目標(biāo)是對信道或整個傳輸系統(tǒng)的頻率響應(yīng)進(jìn)行補(bǔ)償,以減輕或消除由多徑傳播引起的碼間串?dāng)_(ISI)。在數(shù)字通信中,ISI會嚴(yán)重影響接收端的信號質(zhì)量,導(dǎo)致數(shù)據(jù)傳輸錯誤
    的頭像 發(fā)表于 09-10 10:49 ?893次閱讀
    信道<b class='flag-5'>均衡</b>的原理和<b class='flag-5'>分類</b>介紹

    調(diào)速器的主要分類和運(yùn)轉(zhuǎn)方式

    調(diào)速器作為一種用于控制發(fā)動機(jī)轉(zhuǎn)速的裝置,在機(jī)械設(shè)備中起著至關(guān)重要的作用。其分類和運(yùn)轉(zhuǎn)方式多種多樣,以下是對調(diào)速器主要分類和運(yùn)轉(zhuǎn)方式的詳細(xì)解析。
    的頭像 發(fā)表于 08-25 16:42 ?1507次閱讀

    終結(jié)難題的發(fā)明(一):蓄電池均衡系統(tǒng)及其控制方法

    ,大家要明白,動力電池和手機(jī)電池不一樣!2)不均衡的成因,a 效率性差異;b 容量性差異,c 自放電差異。這么多電芯串聯(lián)在一起工作,每個電池是否處于一致的狀態(tài)呢?這就是電池的均衡問題。電池的均衡,說
    發(fā)表于 07-11 13:18

    電機(jī)的分類有哪些

    電機(jī),作為現(xiàn)代工業(yè)、交通、家電等領(lǐng)域中不可或缺的重要設(shè)備,其種類繁多,功能各異。從工作電源、結(jié)構(gòu)和工作原理、起動與運(yùn)行方式、用途、轉(zhuǎn)子結(jié)構(gòu)以及運(yùn)轉(zhuǎn)速度等多個維度出發(fā),電機(jī)可以被劃分為多個類別。本文將詳細(xì)探討電機(jī)的分類,并結(jié)合相關(guān)
    的頭像 發(fā)表于 06-25 14:57 ?972次閱讀

    PLC總線的分類方式

    、高效傳輸。隨著工業(yè)自動化技術(shù)的不斷發(fā)展,PLC總線也逐漸呈現(xiàn)出多樣化的分類,以適應(yīng)不同應(yīng)用場景的需求。本文將詳細(xì)介紹PLC總線的分類方式,并結(jié)合實(shí)際案例和數(shù)據(jù)進(jìn)行深入分析。
    的頭像 發(fā)表于 06-13 17:48 ?947次閱讀

    視頻網(wǎng)站服務(wù)器的四種負(fù)載均衡技術(shù)

    域名并返回多個服務(wù)器的IP地址列表,客戶端會根據(jù)DNS返回的IP地址進(jìn)行請求。這種方式簡單易用,但無法直接控制請求的分配,且存在DNS緩存問題。另外,由于DNS解析的緩存時間,可能導(dǎo)致負(fù)載不均衡。 2、硬件負(fù)載均衡器:硬件負(fù)載
    的頭像 發(fā)表于 04-01 17:36 ?559次閱讀

    深入理解 AFE 的用法:實(shí)現(xiàn)BMS?均衡功能

    。 ? ? AFE? 提供的均衡接口主要是被動均衡,即通過電阻放電;它可以支持內(nèi)部均衡與外部均衡兩種(如下圖,來自于 ADI 的 LTC6813);兩種
    的頭像 發(fā)表于 03-28 15:03 ?4527次閱讀
    深入理解 AFE 的用法:實(shí)現(xiàn)BMS?<b class='flag-5'>均衡</b>功能

    機(jī)器學(xué)習(xí)多分類任務(wù)深度解析

    一對其余其實(shí)更加好理解,每次將一個類別作為正類,其余類別作為負(fù)類。此時共有(N個分類器)。在測試的時候若僅有一個分類器預(yù)測為正類,則對應(yīng)的類別
    發(fā)表于 03-18 10:58 ?1464次閱讀
    機(jī)器學(xué)習(xí)多<b class='flag-5'>分類</b>任務(wù)深度解析

    信道均衡有哪些實(shí)現(xiàn)方式

    信道均衡的實(shí)現(xiàn)方式主要包括線性自動應(yīng)均衡、盲均衡和半盲均衡等。這些方法各有特點(diǎn),選擇哪種方法取決于具體的應(yīng)用場景和性能要求。例如,如果信道變
    的頭像 發(fā)表于 03-02 14:05 ?1333次閱讀

    功放分幾種類型,功放常見分類方式

    功放(Power Amplifier)作為音頻設(shè)備中不可或缺的重要組成部分,廣泛應(yīng)用于音響系統(tǒng)、電視機(jī)、電子設(shè)備等多個領(lǐng)域。功放按照不同的分類方式可以分成多種類型,本文將細(xì)致地介紹功放的常見分類
    的頭像 發(fā)表于 02-23 10:58 ?4309次閱讀

    光模塊類別的5種分類詳解

    光模塊類別的5種分類詳解? 光模塊是光通信領(lǐng)域中非常重要的組件之一,它用于將光信號轉(zhuǎn)換為電信號或者將電信號轉(zhuǎn)換為光信號,在光纖通信、數(shù)據(jù)中心、廣域網(wǎng)和市區(qū)網(wǎng)等領(lǐng)域中廣泛應(yīng)用。根據(jù)功能和使用場
    的頭像 發(fā)表于 12-27 10:50 ?1759次閱讀

    如何確定適合的負(fù)載均衡比例

    其影響以及相關(guān)策略。 什么是負(fù)載均衡比例? 在網(wǎng)絡(luò)中,路由器通常連接著多個網(wǎng)絡(luò)設(shè)備和服務(wù)器。當(dāng)網(wǎng)絡(luò)流量過大時,使用單個設(shè)備處理這些數(shù)據(jù)可能會導(dǎo)致性能下降或網(wǎng)絡(luò)擁塞。為了解決這個問題,負(fù)載均衡技術(shù)被引入,它將
    的頭像 發(fā)表于 12-15 10:36 ?1421次閱讀

    路由器負(fù)載均衡怎么配置

    路由器負(fù)載均衡是一種重要的網(wǎng)絡(luò)技術(shù),它能夠?qū)⒍鄠€網(wǎng)絡(luò)連接的流量分配到多個路由器上,以提高網(wǎng)絡(luò)的性能和穩(wěn)定性。本文將詳細(xì)介紹路由器負(fù)載均衡的配置方法,包括負(fù)載均衡的實(shí)現(xiàn)方式、配置步驟和注
    的頭像 發(fā)表于 12-13 11:17 ?3034次閱讀

    電力系統(tǒng)運(yùn)行方式分類

    一、按聯(lián)接關(guān)系分類 1、并聯(lián)運(yùn)行 各電氣設(shè)備或線路同時帶電運(yùn)行,且在電氣上符合并聯(lián)關(guān)系的運(yùn)行方式。 2、分列運(yùn)行 各電氣設(shè)備或線路同時帶電運(yùn)行,但在電氣上不構(gòu)成并聯(lián)關(guān)系的運(yùn)行方式。 3、一臺(路
    的頭像 發(fā)表于 11-24 15:00 ?5084次閱讀