TinyBERT 是華為不久前提出的一種蒸餾 BERT 的方法,本文梳理了 TinyBERT 的模型結(jié)構,探索了其在不同業(yè)務上的表現(xiàn),證明了 TinyBERT 對復雜的語義匹配任務來說是一種行之有效的壓縮手段。
作者:chenchenliu&winsechang,騰訊 PCG 內(nèi)容挖掘工程師
來源:騰訊技術工程微信號
一、簡介
在 NLP 領域,BERT 的強大毫無疑問,但由于模型過于龐大,單個樣本計算一次的開銷動輒上百毫秒,很難應用到實際生產(chǎn)中。TinyBERT 是華為、華科聯(lián)合提出的一種為基于 transformer 的模型專門設計的知識蒸餾方法,模型大小不到 BERT 的 1/7,但速度提高了 9 倍,而且性能沒有出現(xiàn)明顯下降。目前,該論文已經(jīng)提交機器學習頂會 ICLR 2020。本文復現(xiàn)了 TinyBERT 的結(jié)果,證明了 Tiny BERT 在速度提高的同時,對復雜的語義匹配任務,性能沒有顯著下降。
目前主流的幾種蒸餾方法大概分成利用 transformer 結(jié)構蒸餾、利用其它簡單的結(jié)構比如 BiLSTM 等蒸餾。由于 BiLSTM 等結(jié)構簡單,且一般是用 BERT 最后一層的輸出結(jié)果進行蒸餾,不能學到 transformer 中間層的信息,對于復雜的語義匹配任務,效果有點不盡人意。
基于 transformer 結(jié)構的蒸餾方法目前比較出名的有微軟的 BERT-PKD (Patient Knowledge Distillation for BERT),huggingface 的 DistilBERT,以及本篇文章講的 TinyBERT。他們的基本思路都是減少 transformer encoding 的層數(shù)和 hidden size 大小,實現(xiàn)細節(jié)上各有不同,主要差異體現(xiàn)在 loss 的設計上。
二、模型實現(xiàn)細節(jié)
整個 TinyBERT 的 loss 設計分為三部分:
1. Embedding-layer Distillation
其中:
分別代表 student 網(wǎng)絡的 embedding 和 teacher 網(wǎng)絡的 embedding. 其中 l 代表 sequence length, d0 代表 student embedding 維度, d 代表 teacher embedding 維度。由于 student 網(wǎng)絡的 embedding 層通常較 teacher 會變小以獲得更小的模型和加速,所以 We 是一個 d 0×d 維的可訓練的線性變換矩陣,把 student 的 embedding 投影到 teacher embedding 所在的空間。最后再算 MSE,得到 embedding loss.
2. Transformer-layer Distillation
TinyBERT 的 transformer 蒸餾采用隔 k 層蒸餾的方式。舉個例子,teacher BERT 一共有 12 層,若是設置 student BERT 為 4 層,就是每隔 3 層計算一個 transformer loss. 映射函數(shù)為 g(m) = 3 * m, m 為 student encoder 層數(shù)。具體對應為 student 第 1 層 transformer 對應 teacher 第 3 層,第 2 層對應第 6 層,第 3 層對應第 9 層,第 4 層對應第 12 層。每一層的 transformer loss 又分為兩部分組成,attention based distillation 和 hidden states based distillation.
2.1 Attention based loss
其中,
h 代表 attention 的頭數(shù),l 代表輸入長度,
代表 student 網(wǎng)絡第 i 個 attention 頭的 attention score 矩陣,
代表 teacher 網(wǎng)絡第 i 個 attention 頭的 attention score 矩陣。這個 loss 是受到斯坦福和 Facebook 聯(lián)合發(fā)表的論文,What Does BERT Look At? An Analysis of BERT’s Attention 的啟發(fā)。這篇論文研究了 attention 權重到底學到了什么,實驗發(fā)現(xiàn)與語義還有語法相關的詞比如第一個動詞賓語,第一個介詞賓語,以及[CLS], [SEP], 逗號等 token,有很高的注意力權重。為了確保這部分信息能被 student 網(wǎng)絡學到,TinyBERT 在 loss 設計中加上了 student 和 teacher 的 attention matrix 的 MSE。這樣語言知識可以很好的從 teacher BERT 轉(zhuǎn)移到 student BERT.
2.2 hidden states based distillation
其中,
分別是 student transformer 和 teacher transformer 的隱層輸出。和 embedding loss 同理,
投影到 Ht 所在的空間。
3. Prediction-Layer Distillation
其中 t 是 temperature value,暫時設為 1.除了模仿中間層的行為外,這一層用來模擬 teacher 網(wǎng)絡在 predict 層的表現(xiàn)。具體來說,這一層計算了 teacher 輸出的概率分布和 student 輸出的概率分布的 softmax 交叉熵。這一層的實現(xiàn)和具體任務相關,我們的兩個實驗分別采取了 BERT 原生的 masked language model loss + next sentence loss 和單任務的 classification softmax cross-entropy.
另外,值得一提的是 prediction loss 有很多變化。在 TinyBERT 中,這個 loss 是 teacher BERT 預測的概率和 student BERT 預測概率的 softmax 交叉熵,在 BERT-PKD 模型中,這個 loss 是 teacher BERT 和 student BERT 的交叉熵和 student BERT 和 hard target( one-hot)的交叉熵的加權平均。我們在業(yè)務中有試過直接用 hard target loss,效果比使用 teacher student softmax 交叉熵下降 5-6 個點。因為 softmax 比 one-hot 編碼了更多概率分布的信息。并且實驗中,softmax cross-entropy loss 容易發(fā)生不收斂的情況,把 softmax 交叉熵改成 MSE, 收斂效果變好,但泛化效果變差。這是因為使用 softmax cross-entropy 需要學到整個概率分布,更難收斂,因為擬合了 teacher BERT 的概率分布,有更強的泛化性。MSE 對極值敏感,收斂的更快,但泛化效果不如前者。
所以總結(jié)一下,loss 的計算公式為:
其中,
三、實驗
TinyBERT 論文中提出了兩階段學習框架,比較新穎。類似于原生的 BERT 先 pre-train, 根據(jù)具體任務再 fine-tine, TinyBERT 先在 general domain 數(shù)據(jù)集上用未經(jīng)微調(diào)的 BERT 充當教師蒸餾出一個 base 模型,在此基礎上,具體任務通過數(shù)據(jù)增強,利用微調(diào)后的 BERT 再進行重新執(zhí)行蒸餾。
這種兩階段的方法給 TinyBERT 提供了像 BERT 一樣的泛化能力。不過為了快速得到實驗結(jié)果,并且論文中的控制變量實驗顯示 general 的蒸餾對各項下游任務的影響較小,我們此次選擇直接用 fine-tune 過的 teacher BERT,蒸餾得到 student BERT.
所以我們蒸餾 TinyBERT 的流程是:
制作任務相關數(shù)據(jù)集;
fine-tune teacher BERT;
固定 teacher BERT 參數(shù),蒸餾得到 TinyBERT.
關于實驗結(jié)果,先上 TinyBERT 論文中的結(jié)論:
可以看到 TinyBERT 表現(xiàn)優(yōu)異。在 GLUE 上,相較于完整的 BERT,性能下降 3 個點,但是推理性能卻得到了巨大提升,快了 9 倍多。
我們在自己的業(yè)務上,也用 TinyBERT 得到了相似的結(jié)果。
3.1 文章連貫性特征任務
做這個特征的目的是為了過濾東拼西湊或者機器生成前后沒有邏輯的文章。由于語義的復雜性還有語義的轉(zhuǎn)移,這個任務和語義相似度任務略有不同,文章的上下句之間語義會有不同。在這個背景下,實驗過 DSSM, Match-Pyramid 等模型,表現(xiàn)效果較差。
由于 BERT 能學到豐富的語義,這個任務目前采用 BERT 的 next sentence 任務較為合適。但是一旦文章很長,原生 BERT 需要算 1 秒甚至更久,這樣的速度是不能接受的。
TinyBERT 在不同實驗參數(shù)下的表現(xiàn)如下:
可以看到 4 層 encoder 的 TinyBERT 在 next sentence 任務下準確率較 BERT base 準確率下降了不到 3 個點,在 mlm 任務上下降較多。在 CPU 上,TinyBERT 相較于 base 速度獲得了將近 8 倍的提升。
3.2 問答 FAQ 任務
業(yè)務場景:為用戶的 query 匹配最接近的 question,將其 answer 返回,是一個 Query-question 語義匹配任務。
下面是蒸餾到兩層 encoder 的 TinyBERT 結(jié)果:
可以看到 ACC 損失 3 個點,AUC 損失 4 個點,取得了不錯的效果。
四、總結(jié)
我們證明了 TinyBERT 作為一種蒸餾方法,能有效的提取 BERT transformer 結(jié)構中豐富的語意信息,在不犧牲性能的情況下,速度能獲得 8 到 9 倍的提升。下一步可能會嘗試蒸餾一個 general 的 TinyBERT base。
更多騰訊AI相關技術干貨,請關注專欄騰訊技術工程
審核編輯 黃昊宇
-
人工智能
+關注
關注
1789文章
46652瀏覽量
237071 -
nlp
+關注
關注
1文章
484瀏覽量
21987
發(fā)布評論請先 登錄
相關推薦
評論