編者按:反向傳播是一種訓(xùn)練人工神經(jīng)網(wǎng)絡(luò)的常見方法,它能簡化深度模型在計算上的處理方式,是初學(xué)者必須熟練掌握的一種關(guān)鍵算法。對于現(xiàn)代神經(jīng)網(wǎng)絡(luò),通過反向傳播,我們能配合梯度下降大幅提高模型的訓(xùn)練速度,在一周時間內(nèi)就完成以往研究人員可能要耗費兩萬年才能完成的模型。
除了深度學(xué)習(xí),反向傳播算法在許多其他領(lǐng)域也是一個強大的計算工具,從天氣預(yù)報到分析數(shù)值穩(wěn)定性——區(qū)別只在于名稱差異。事實上,這種算法在幾十個不同的領(lǐng)域都有成熟應(yīng)用,無數(shù)研究人員都為這種“反向模式求導(dǎo)”的形式著迷。
從根本上說,無論是深度學(xué)習(xí)還是其他數(shù)值計算環(huán)境,這是一種方便快速計算的方法,也是一個必不可少的計算竅門。
計算圖
談及計算,有人可能又要為煩人的計算公式頭疼了,所以本文用了一種思考數(shù)學(xué)表達式的輕松方法——計算圖。以非常簡單的e=(a+b)×(b+1)為例,從計算角度看它一共有3步操作:兩次求和和一次乘積。為了讓大家對計算圖有更清晰的理解,這里我們把它分開計算,并繪制圖像。
我們可以把這個等式分成3個函數(shù):
在計算圖中,我們把每個函數(shù)連同輸入變量一起放進節(jié)點中。如果當前節(jié)點是另一個節(jié)點的輸入,用帶剪頭的線表示數(shù)據(jù)流向:
這其實是計算機科學(xué)中的一種常見描述方法,尤其是在討論涉及函數(shù)的程序時,它非常有用。此外,現(xiàn)在流行的大多數(shù)深度學(xué)習(xí)開源框架,比如TensorFlow、Caffe、CNTK、Theano等,都采用了計算圖。
仍以之前的例子為例,在計算圖中,我們可以通過設(shè)置輸入變量為特定值來計算表達式。如,我們設(shè)a=2,b=1:
可以得到e=(a+b)×(b+1)=6。
計算圖上的導(dǎo)數(shù)
如果要理解計算圖上的導(dǎo)數(shù),一個關(guān)鍵在于我們?nèi)绾卫斫饷恳粭l帶箭頭的線(下稱“邊”)上的導(dǎo)數(shù)。以之前的連接a節(jié)點和c=a+b節(jié)點的邊為例,如果a對c有影響,那這是個怎么樣的影響?如果a變化了,c會怎么變化?我們稱這為c關(guān)于a的偏導(dǎo)數(shù)。
為了計算圖中的偏導(dǎo)數(shù),我們先來復(fù)習(xí)這兩個求和規(guī)則和乘積規(guī)則:
已知a=2,b=1,那么相應(yīng)的計算圖就是:
現(xiàn)在我們計算出了相鄰兩個節(jié)點的偏導(dǎo)數(shù),如果我想知道不直接相連的節(jié)點是如何相互影響的,你會怎么辦?如果我們以速率為1的速度變化輸入a,那么根據(jù)偏導(dǎo)數(shù)可知,函數(shù)c的變化速率也是1,已知e相對于c的偏導(dǎo)數(shù)是2,那么同樣的,e相對a的變化速率也是2。
計算不直接相連節(jié)點之間偏導(dǎo)數(shù)的一般規(guī)則是計算各路徑偏導(dǎo)數(shù)的和,而同一路徑偏導(dǎo)數(shù)則是各邊偏導(dǎo)數(shù)的乘積,例如,e關(guān)于b的偏導(dǎo)數(shù)就等于:
上式表示了b是如何通過影響函數(shù)c和d來影響函數(shù)e的。
像這種一般的“路徑求和”規(guī)則只是對多元鏈式規(guī)則的不同思考方式。
路徑分解
“路徑求和”的問題在于,如果我們只是簡單粗暴地計算每條可能路徑的偏導(dǎo)數(shù),我們很可能會最后得到一個“爆炸”的和。
如上圖所示,X到Y(jié)有3條路徑,Y到Z也有3條路徑,如果要計算?Z/?X,我們要計算的是3×3=9條路徑的偏導(dǎo)數(shù)的和:
這還只是9條,隨著模型變得越來越復(fù)雜,相應(yīng)的計算復(fù)雜度也會呈指數(shù)級上升。因此比起傻乎乎地一個個求和,我們最好能記起一些小學(xué)數(shù)學(xué)知識,然后把上式轉(zhuǎn)為:
是不是很眼熟?這就是前向傳播算法和反向傳播算法中最基礎(chǔ)的一個偏導(dǎo)數(shù)等式。通過分解路徑,這個式子能更高效地計算總和,雖然長得和求和等式有一定差異,但對于每條邊它確實只計算了一次。
前向模式求導(dǎo)從計算圖的輸入開始,到最后結(jié)束。在每個節(jié)點上,它匯總了所有輸入的路徑,每條路徑代表輸入影響該節(jié)點的一種方式。相加后,我們就能得到輸入對最終結(jié)果的總的影響,也就是偏導(dǎo)數(shù)。
雖然你以前可能沒想過從計算圖的角度來進行理解,但這樣一看,其實前向模式求導(dǎo)和我們剛開始學(xué)微積分時接觸的內(nèi)容差不多。
另一方面,反向模式求導(dǎo)則是從計算圖的最后開始,到輸入結(jié)束。對于每個節(jié)點,它做的是合并所有源自該節(jié)點的路徑。
前向模式求導(dǎo)關(guān)注的是一個輸入如何影響每個節(jié)點,反向模式求導(dǎo)關(guān)注的是每個節(jié)點如何影響最后那一個輸出。換句話說,就是前向模式求導(dǎo)是在把?/?X塞進每個節(jié)點,反向模式求導(dǎo)是在把?Z/?塞進每個節(jié)點。
大功告成
說到現(xiàn)在,你可能會想知道反向模式求導(dǎo)究竟有什么意義。它看起來就是前向模式求導(dǎo)的一個奇怪翻版,其中會有什么優(yōu)勢嗎?
讓我們從之前的那張計算圖開始:
我們先用前向模式求導(dǎo)計算輸入b對各個節(jié)點的影響:
?e/?b=5。我們把這個放一邊,再來看看反向模式求導(dǎo)的情況:
之前我們說反向模式求導(dǎo)關(guān)注的是每個節(jié)點如何影響最后那個輸出,根據(jù)上圖可以發(fā)現(xiàn),圖中偏導(dǎo)數(shù)既有?e/?b的,也有?e/?a的。這是因為這個模型有兩個輸入,而它們都對輸出e產(chǎn)生了影響。也就是說,反向模式求導(dǎo)更能反映全局輸入情況。
如果說這是一個只有兩個輸入的簡單例子,兩種方法都無所謂,那么請想象一個有一百萬個輸入、只有一個輸出的模型。像這樣的模型,我們用前向模式求導(dǎo)要算一百萬次,用反向模式求導(dǎo)只要算1次,這就高下立判了!
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時,我們把cost(描述網(wǎng)絡(luò)表現(xiàn)好壞的值)視作一個包含各類參數(shù)(描述網(wǎng)絡(luò)行為方式的數(shù)字)的函數(shù)。為了提升模型性能,我們要不斷改變參數(shù)對cost函數(shù)求導(dǎo),以此進行梯度下降。模型的參數(shù)千千萬,但它的輸出只有一個,因此機器學(xué)習(xí)對于反向模式求導(dǎo),也就是反向傳播算法來說是個再適合不過的應(yīng)用領(lǐng)域。
那有沒有一種情況下,前向模式求導(dǎo)能比反向模式求導(dǎo)更好?有的!我們到現(xiàn)在談的都是多輸入單輸出的情形,這時反向更好;如果是一輸入多輸出、多輸入多輸出,前向模式求導(dǎo)速度更快!
這不是太普通了嗎?
當我第一次真正理解反向傳播算法時,我的反應(yīng)是:哦,就是最簡單的鏈式法則!我怎么花了這么久才明白?事實上我也不是唯一出現(xiàn)這種反應(yīng)的人,的確,如果問題是你能從前向模式求導(dǎo)中推出那種更聰明的計算方法,這就沒那么麻煩了。
但我認為這比看起來要困難得多。在反向傳播算法剛發(fā)明的時候,人們其實并沒有十分關(guān)注前饋神經(jīng)網(wǎng)絡(luò)的研究。所以也沒人發(fā)現(xiàn)它的衍生品有利于快速計算。但當大家都知道這種衍生品的好處后,他們又開始反應(yīng)過來:原來它們有這樣的關(guān)系!這之中有一個惡性循環(huán)。
更糟糕的是,在腦子里推一推算法的衍生工具是很普遍的,一旦涉及用它們訓(xùn)練神經(jīng)網(wǎng)絡(luò),這幾乎就等同于洪水猛獸。你肯定會陷入局部最小值!你可能會浪費巨大的計算成本!人們只有在確認這種方法有效后,才會乖乖閉嘴去實踐。
小結(jié)
衍生工具比你想象中的更易于挖掘,也更好用,我希望這是本文為你帶來的主要經(jīng)驗。雖然事實上這個挖掘過程并不容易,但在深度學(xué)習(xí)中領(lǐng)會這一點很重要,換一個角度,我們就能發(fā)現(xiàn)不同的風(fēng)景。同樣的話也適用于其他領(lǐng)域。
還有其他經(jīng)驗嗎?我認為有。
反向傳播算法也是了解數(shù)據(jù)流經(jīng)模型過程的有利“鏡頭”,我們能用它知道為什么有些模型會難以優(yōu)化,如經(jīng)典的遞歸神經(jīng)網(wǎng)絡(luò)中梯度消失的問題。
最后,讀者可以嘗試同時結(jié)合前向傳播和反向傳播兩種算法來進行更有效的計算。如果你真的理解了這兩種算法的技巧,你會發(fā)現(xiàn)其中會有不少有趣的衍生表達式。
-
神經(jīng)網(wǎng)絡(luò)
+關(guān)注
關(guān)注
42文章
4749瀏覽量
100433 -
計算圖
+關(guān)注
關(guān)注
0文章
9瀏覽量
6894 -
深度學(xué)習(xí)
+關(guān)注
關(guān)注
73文章
5471瀏覽量
120903
原文標題:計算圖演算:反向傳播
文章出處:【微信號:jqr_AI,微信公眾號:論智】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論