理論介紹
相比于訓(xùn)練后量化方法,將量化過(guò)程插入到訓(xùn)練中可以彌補(bǔ)量化產(chǎn)生的誤差,但是帶來(lái)的問(wèn)題可能是增加了訓(xùn)練的時(shí)間。在tansformer的量化實(shí)現(xiàn)中,我們采用了訓(xùn)練中量化的方法,在網(wǎng)絡(luò)前向傳輸中,對(duì)權(quán)重等參數(shù)進(jìn)行線性量化。反向傳播中,對(duì)scale和權(quán)重參數(shù)的求導(dǎo)采用Hinton的strait-through estimator的方式。在CPU上訓(xùn)練花費(fèi)了10天的時(shí)間,在這期間又review了最近的量化方法的文章。所以先總結(jié)一下,然后再分析一下transformer量化的結(jié)果。
1) PACT
這是一種實(shí)現(xiàn)對(duì)activation量化的方法,基本思想是通過(guò)訓(xùn)練來(lái)獲得ReLU的一個(gè)clip參數(shù)a。a的動(dòng)態(tài)調(diào)整能夠在減少量化誤差和保證反向傳播有效進(jìn)行之間獲得平衡。PACT重新定義了ReLU過(guò)程如下:
參數(shù)a限定了activation的范圍為[0, a]。然后獲得的激活值y在進(jìn)行線性映射到k bit的表示空間,如下:
在這里[0, a]是y值的一個(gè)限定,a>=y。所以其范圍比y值的實(shí)際范圍要大,這可以對(duì)y的量化誤差有一些彌補(bǔ)。采用strait-through estimator方法計(jì)算其相對(duì)于a的梯度為:
當(dāng)a趨向于無(wú)窮大的時(shí)候,就接近于ReLU函數(shù),所以訓(xùn)練過(guò)程一定是往a增大方向移動(dòng)。通過(guò)在loss中增加a的L2 規(guī)范化可以尋求一個(gè)合適的a值。
2) quantization-aware training
谷歌采用量化和訓(xùn)練分離的方法,在前向計(jì)算使用量化數(shù)據(jù),而在訓(xùn)練的時(shí)候還是浮點(diǎn)訓(xùn)練。量化方法為如下公式:
其中S為scale參數(shù),z是零點(diǎn)偏移,q是量化后參數(shù)。Z值的存在會(huì)導(dǎo)致矩陣或者卷積運(yùn)算中有交叉項(xiàng)。這會(huì)增加一部分加法和乘法項(xiàng)。這在CPU等通用處理器上容易實(shí)現(xiàn),只是一個(gè)時(shí)間復(fù)雜度的問(wèn)題,但是實(shí)際上不利于在FPGA等硬件上實(shí)現(xiàn)。所以FPGA等平臺(tái)的量化一般都讓z值為0。消除交叉項(xiàng)計(jì)算。對(duì)于一個(gè)矩陣乘法,量化導(dǎo)致了scale的組合,比如:
在這里M是浮點(diǎn)數(shù)據(jù),在這里作者對(duì)其又做了一次量化,首先將M數(shù)據(jù)映射到[0.5, 1)空間,然后在使用32bit數(shù)據(jù)來(lái)表達(dá)為整數(shù)。
32bit的表達(dá)能夠降低量化精度。
在量化整個(gè)網(wǎng)絡(luò)的過(guò)程中,作者也提供了一些處理技巧。在進(jìn)行線性量化的時(shí)候,采用了對(duì)稱的量化區(qū)間,比如8bit量化,正常取值范圍在[-128, 127],作者取了對(duì)稱空間[-127, 127]。這樣做的目的和實(shí)現(xiàn)的平臺(tái)有關(guān)。在量化activation的時(shí)候,使用EMA來(lái)處理收集到的數(shù)值范圍,這樣做可以在初始訓(xùn)練中,完全屏蔽掉對(duì)activation的量化,使得訓(xùn)練進(jìn)入到一個(gè)比較穩(wěn)定的狀態(tài)后在進(jìn)行量化。BN是一個(gè)復(fù)雜的計(jì)算,但是可以將其折疊到之前的卷積層和FC層中,如下圖所示:
3) 訓(xùn)練后量化,基于KL發(fā)散性。
基于訓(xùn)練后的量化方法的優(yōu)勢(shì)就是量化花費(fèi)時(shí)間短。在tensorRT中使用了KL發(fā)散性來(lái)描述量化后的數(shù)據(jù)和浮點(diǎn)數(shù)據(jù)之間的信息損失程度。通過(guò)最小化這個(gè)值來(lái)達(dá)到量化后數(shù)據(jù)包含的信息接近浮點(diǎn)數(shù)據(jù)的信息。這種方法的出發(fā)點(diǎn)是,為了保證量化后模型的精度損失較小,應(yīng)該讓量化后的數(shù)據(jù)和原始浮點(diǎn)數(shù)據(jù)表達(dá)的信息最一致。具體的做法是:
對(duì)每層網(wǎng)絡(luò),先收集activation的數(shù)值區(qū)間,這樣就生成一個(gè)activation值的分布;采用不同的量化區(qū)間[a,b]來(lái)對(duì)activation進(jìn)行線性映射,這樣就形成了針對(duì)參數(shù)a和b的多種不同分布,然后找到和原始數(shù)據(jù)分布KL最小的分布,這個(gè)時(shí)候得到的a和b的值就是量化activation時(shí)所采用的threshold值。
Transformer量化結(jié)果
還是決定由簡(jiǎn)入難,先進(jìn)行16bit的量化,量化內(nèi)容包括transformer中的dense層,F(xiàn)C層。對(duì)權(quán)重和數(shù)據(jù)都進(jìn)行16bit的量化,即將量化節(jié)點(diǎn)插入到計(jì)算圖中。梯度采用strait-through estimator來(lái)估計(jì)。對(duì)于embedding,softmax,layer normalization還是使用浮點(diǎn)值。因?yàn)閾?dān)心對(duì)這些的量化可能會(huì)導(dǎo)致精度降低。選擇batch size為256,epoch為20,數(shù)據(jù)集使用英語(yǔ)德語(yǔ)翻譯數(shù)據(jù)集。這個(gè)數(shù)據(jù)集有460萬(wàn)個(gè)句子。在服務(wù)器上使用CPU跑了10天,以下是結(jié)果:
對(duì)比一下github上作者浮點(diǎn)模型的訓(xùn)練結(jié)果:
發(fā)現(xiàn)存在以下問(wèn)題:
1 loss下降很慢,浮點(diǎn)模型在訓(xùn)練到達(dá)5k次的時(shí)候,loss已經(jīng)下降到4了,但是量化的訓(xùn)練loss在5k次的時(shí)候才到5.4。經(jīng)歷了前幾次快速下降之后,后邊更加緩慢。
2 BLEU得分很低,訓(xùn)練了10K次后得分才有0.11。得分低的原因也是loss值很低。
第一次做沒(méi)有什么經(jīng)驗(yàn),猜測(cè)可能有以下幾種原因:
1 對(duì)所有的scale我都使用了常數(shù)2作為初始值,為什么選擇2,并沒(méi)有什么原因,就是隨便選擇的?;蛟S初始值的不當(dāng)導(dǎo)致了loss訓(xùn)練很慢。設(shè)想通過(guò)以下方式來(lái)改進(jìn),先進(jìn)行warmup,通過(guò)計(jì)算參數(shù)的范圍來(lái)計(jì)算出一個(gè)scale值。進(jìn)行了幾輪warmup之后再進(jìn)行量化訓(xùn)練。
2 因?yàn)榭吹絣oss也一直是下降的趨勢(shì),那么猜測(cè)可能是量化訓(xùn)練是比正常訓(xùn)練收斂慢。因?yàn)榱炕瘏?shù)的梯度在參數(shù)超過(guò)閾值會(huì)為0,這個(gè)可能導(dǎo)致梯度更新較慢。
編輯:hfy
-
cpu
+關(guān)注
關(guān)注
68文章
10804瀏覽量
210829 -
Transformer
+關(guān)注
關(guān)注
0文章
139瀏覽量
5968
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論