大家好,生成式大模型的熱度,從今年3月開始已經(jīng)燃了一個(gè)多季度了。在這個(gè)季度中,相信大家肯定看過很多AI產(chǎn)生的有趣內(nèi)容,比如著名的抓捕川普現(xiàn)場與監(jiān)獄風(fēng)云 [現(xiàn)在來看不僅是畫得像而已],AI換聲孫燕姿等。這背后都用到了一個(gè)強(qiáng)大的模型:Diffusion Model。所以,在這個(gè)系列中,我們將從原理到源碼,從基石DDPM到DALLE2,Imagen與Stable Diffusion,通過詳細(xì)的圖例和解說,和大家一起來了解擴(kuò)散模型的奧秘。同時(shí),也會穿插對經(jīng)典的GAN,VAE等模型的解讀,敬請期待~
本篇將和大家一起解讀擴(kuò)散模型的基石:DDPM(Denoising Diffusion Probalistic Models)。擴(kuò)散模型的研究并不始于DDPM,但DDPM的成功對擴(kuò)散模型的發(fā)展起到至關(guān)重要的作用。在這個(gè)系列里我們也會看到,后續(xù)一連串效果驚艷的模型,都是在DDPM的框架上迭代改進(jìn)而來。所以,我把DDPM放在這個(gè)系列的第一篇進(jìn)行講解。
初讀DDPM論文的朋友,可能有以下兩個(gè)痛點(diǎn):
論文花極大篇幅講數(shù)學(xué)推導(dǎo),可是我看不懂。
論文沒有給出模型架構(gòu)圖和詳細(xì)的訓(xùn)練解說,而這是我最關(guān)心的部分。
針對這些痛點(diǎn),DDPM系列將會出如下三篇文章:
DDPM(模型架構(gòu)篇):也就是本篇文章。在閱讀源碼的基礎(chǔ)上,本文繪制了詳細(xì)的DDPM模型架構(gòu)圖,同時(shí)附上關(guān)于模型運(yùn)作流程的詳細(xì)解說。本文不涉及數(shù)學(xué)知識,直觀幫助大家了解DDPM怎么用,為什么好用。
DDPM(人人都能看懂的數(shù)學(xué)推理篇):DDPM的數(shù)學(xué)推理可能是很多讀者頭疼的部分。我嘗試跳出原始論文的推導(dǎo)順序和思路,從更符合大家思維模式的角度入手,把整個(gè)推理流程串成一條完整的邏輯線。同樣,我也會配上大量的圖例,方便大家理解數(shù)學(xué)公式。如果你不擅長數(shù)學(xué)推導(dǎo),這篇文章可以幫助你從直覺上了解DDPM的數(shù)學(xué)有效性;如果你更關(guān)注推導(dǎo)細(xì)節(jié),這篇文章中也有詳細(xì)的推導(dǎo)中間步驟。
DDPM(源碼解讀篇):在前兩篇的基礎(chǔ)上,我們將配合模型架構(gòu)圖,一起閱讀DDPM源碼,并實(shí)操跑一次,觀測訓(xùn)練過程里的中間結(jié)果。
本文目錄如下:
一、DDPM在做一件什么事
二、DDPM訓(xùn)練流程:Diffusion/Denoise Process
三、DDPM的training與sampling
四、圖解DDPM核心模型架構(gòu):UNet
五、文生圖模型的一般公式
六、參考
最后,Megatron源碼解讀的系列沒有停更!只是兩個(gè)系列穿插來寫。
一、DDPM在做一件什么事
假設(shè)你想做一個(gè)以文生圖的模型,你的目的是給一段文字,再隨便給一張圖(比如一張?jiān)肼暎?,這個(gè)模型能幫你產(chǎn)出符合文字描述的逼真圖片,例如:
文字描述就像是一個(gè)指引(guidance),幫助模型去產(chǎn)生更符合語義信息的圖片。但是,畢竟語義學(xué)習(xí)是復(fù)雜的。我們能不能先退一步,先讓模型擁有產(chǎn)生逼真圖片的能力?
比如說,你給模型喂一堆cyperpunk風(fēng)格的圖片,讓模型學(xué)會cyberpunk風(fēng)格的分布信息,然后喂給模型一個(gè)隨機(jī)噪音,就能讓模型產(chǎn)生一張逼真的cyberpunk照片?;蛘呓o模型喂一堆人臉圖片,讓模型產(chǎn)生一張逼真的人臉。同樣,我們也能選擇給訓(xùn)練好的模型喂帶點(diǎn)信息的圖片,比如一張夾雜噪音的人臉,讓模型幫我們?nèi)ピ搿?/p>
具備了產(chǎn)出逼真圖片的能力,模型才可能在下一步中去學(xué)習(xí)語義信息(guidance),進(jìn)一步產(chǎn)生符合人類意圖的圖片。而DDPM的本質(zhì)作用,就是學(xué)習(xí)訓(xùn)練數(shù)據(jù)的分布,產(chǎn)出盡可能符合訓(xùn)練數(shù)據(jù)分布的真實(shí)圖片。所以,它也成為后續(xù)文生圖類擴(kuò)散模型框架的基石。
二、DDPM訓(xùn)練流程
理解DDPM的目的,及其對后續(xù)文生圖的模型的影響,現(xiàn)在我們可以更好來理解DDPM的訓(xùn)練過程了??傮w來說,DDPM的訓(xùn)練過程分為兩步:
Diffusion Process (又被稱為Forward Process)
Denoise Process(又被稱為Reverse Process)
前面說過,DDPM的目的是要去學(xué)習(xí)訓(xùn)練數(shù)據(jù)的分布,然后產(chǎn)出和訓(xùn)練數(shù)據(jù)分布相似的圖片。那怎么“迫使”模型去學(xué)習(xí)呢?
一個(gè)簡單的想法是,我拿一張干凈的圖,每一步(timestep)都往上加一點(diǎn)噪音,然后在每一步里,我都讓模型去找到加噪前圖片的樣子,也就是讓模型學(xué)會去噪。這樣訓(xùn)練完畢后,我再塞給模型一個(gè)純噪聲,它不就能一步步幫我還原出原始圖片的分布了嗎?
一步步加噪的過程,就被稱為Diffusion Process;一步步去噪的過程,就被稱為Denoise Process。我們來詳細(xì)看這兩步。
2.1 Diffusion Process
Diffusion Process的命名受到熱力學(xué)中分子擴(kuò)散的啟發(fā):分子從高濃度區(qū)域擴(kuò)散至低濃度區(qū)域,直至整個(gè)系統(tǒng)處于平衡。加噪過程也是同理,每次往圖片上增加一些噪聲,直至圖片變?yōu)橐粋€(gè)純噪聲為止。整個(gè)過程如下:
如圖所示,我們進(jìn)行了1000步的加噪,每一步我們都往圖片上加入一個(gè)高斯分布的噪聲,直到圖片變?yōu)橐粋€(gè)純高斯分布的噪聲。
我們記:
:總步數(shù)
:每一步產(chǎn)生的圖片。其中為原始圖片,為純高斯噪聲
:為每一步添加的高斯噪聲
:在條件下的概率分布。如果你覺得抽象,可以理解成已知,求
那么根據(jù)以上流程圖,我們有:
根據(jù)公式,為了知道,需要sample好多次噪聲,感覺不太方便,能不能更簡化一些呢?
我們知道隨著步數(shù)的增加,圖片中原始信息含量越少,噪聲越多,我們可以分別給原始圖片和噪聲一個(gè)權(quán)重來計(jì)算:
:一系列常數(shù),類似于超參數(shù),隨著的增加越來越小。
則此時(shí)的計(jì)算可以設(shè)計(jì)成:
現(xiàn)在,我們只需要sample一次噪聲,就可以直接得到了。
接下來,我們再深入一些,其實(shí)并不是我們直接設(shè)定的超參數(shù),它是根據(jù)其它超參數(shù)推導(dǎo)而來,這個(gè)“其它超參數(shù)”指:
:一系列常數(shù),是我們直接設(shè)定的超參數(shù),隨著T的增加越來越大
則和的關(guān)系為:
這樣從原始加噪到加噪,再到加噪,使得轉(zhuǎn)換成的過程,就被稱為重參數(shù)(Reparameterization)。我們會在這個(gè)系列的下一篇(數(shù)學(xué)推導(dǎo)篇)中進(jìn)一步探索這樣做的目的和可行性。在本篇中,大家只需要從直覺上理解它的作用方式即可。
2.2 Denoise Process
Denoise Process的過程與Diffusion Process剛好相反:給定,讓模型能把它還原到。在上文中我們曾用這個(gè)符號來表示加噪過程,這里我們用來表示去噪過程。由于加噪過程只是按照設(shè)定好的超參數(shù)進(jìn)行前向加噪,本身不經(jīng)過模型。但去噪過程是真正訓(xùn)練并使用模型的過程。所以更進(jìn)一步,我們用來表示去噪過程,其中表示模型參數(shù),即:
:用來表示Diffusion Process
:用來表示Denoise Process。
講完符號表示,我們來具體看去噪模型做了什么事。如下圖所示,從第T個(gè)timestep開始,模型的輸入為與當(dāng)前timestep 。模型中蘊(yùn)含一個(gè)噪聲預(yù)測器(UNet),它會根據(jù)當(dāng)前的輸入預(yù)測出噪聲,然后,將當(dāng)前圖片減去預(yù)測出來的噪聲,就可以得到去噪后的圖片。重復(fù)這個(gè)過程,直到還原出原始圖片為止:
你可能想問:
為什么我們的輸入中要包含time_step?
為什么通過預(yù)測噪聲的方式,就能讓模型學(xué)得訓(xùn)練數(shù)據(jù)的分布,進(jìn)而產(chǎn)生逼真的圖片?
第二個(gè)問題的答案我們同樣放在下一篇(數(shù)學(xué)推理篇)中進(jìn)行詳解。而對于第一個(gè)問題,由于模型每一步的去噪都用的是同一個(gè)模型,所以我們必須告訴模型,現(xiàn)在進(jìn)行的是哪一步去噪。因此我們要引入timestep。timestep的表達(dá)方法類似于Transformer中的位置編碼(可以參考這篇文章),將一個(gè)常數(shù)轉(zhuǎn)換為一個(gè)向量,再和我們的輸入圖片進(jìn)行相加。
注意到,UNet模型是DDPM的核心架構(gòu),我們將關(guān)于它的介紹放在本文的第四部分。
到這里為止,如果不考慮整個(gè)算法在數(shù)學(xué)上的有效性,我們已經(jīng)能從直覺上理解擴(kuò)散模型的運(yùn)作流程了。那么,我們就可以對它的訓(xùn)練和推理過程來做進(jìn)一步總結(jié)了。
三、DDPM的Training與Sampling過程
3.1 DDPM Training
上圖給出了DDPM論文中對訓(xùn)練步驟的概述,我們來詳細(xì)解讀它。
前面說過,DDPM模型訓(xùn)練的目的,就是給定time_step和輸入圖片,結(jié)合這兩者去預(yù)測圖片中的噪聲。
我們知道,在重參數(shù)的表達(dá)下,第t個(gè)時(shí)刻的輸入圖片可以表示為:
也就是說,第t個(gè)時(shí)刻sample出的噪聲,就是我們的噪聲真值。而我們預(yù)測出來的噪聲為:
,其中為模型參數(shù),表示預(yù)測出的噪聲和模型相關(guān)。那么易得出我們的loss為:
我們只需要最小化該loss即可。
由于不管對任何輸入數(shù)據(jù),不管對它的任何一步,模型在每一步做的都是去預(yù)測一個(gè)來自高斯分布的噪聲。因此,整個(gè)訓(xùn)練過程可以設(shè)置為:
從訓(xùn)練數(shù)據(jù)中,抽樣出一條(即)
隨機(jī)抽樣出一個(gè)timestep。(即)
隨機(jī)抽樣出一個(gè)噪聲(即)
計(jì)算:
計(jì)算梯度,更新模型,重復(fù)上面過程,直至收斂
上面演示的是單條數(shù)據(jù)計(jì)算loss的過程,當(dāng)然,整個(gè)過程也可以在batch范圍內(nèi)做,batch中單條數(shù)據(jù)計(jì)算loss的方法不變。
3.2 DDPM的Sampling
當(dāng)DDPM訓(xùn)練好之后,我們要怎么用它,怎么評估它的效果呢?
對于訓(xùn)練好的模型,我們從最后一個(gè)時(shí)刻(T)開始,傳入一個(gè)純噪聲(或者是一張加了噪聲的圖片),逐步去噪。根據(jù)x_tx_{t-1}的關(guān)系(上圖的前半部分)。而圖中一項(xiàng),則不是直接推導(dǎo)而來的,是我們?yōu)榱嗽黾油评碇械碾S機(jī)性,而額外增添的一項(xiàng)??梢灶惐扔贕PT中為了增加回答的多樣性,不是選擇概率最大的那個(gè)token,而是在topN中再引入方法進(jìn)行隨機(jī)選擇。
關(guān)于和關(guān)系的詳細(xì)推導(dǎo),我們也放在數(shù)學(xué)推理篇中做解釋。
通過上述方式產(chǎn)生的,我們可以計(jì)算它和真實(shí)圖片分布之間的相似度(FID score:Frechet Inception Distance score)來評估圖片的逼真性。在DDPM論文中,還做了一些有趣的實(shí)驗(yàn),例如通過“插值(interpolation)”方法,先對兩張任意的真實(shí)圖片做Diffusion過程,然后分別給它們的diffusion結(jié)果附不同的權(quán)重(),將兩者diffusion結(jié)果加權(quán)相加后,再做Denoise流程,就可以得到一張很有意思的"混合人臉":
到目前為止,我們已經(jīng)把整個(gè)DDPM的核心運(yùn)作方法講完了。接下來,我們來看DDPM用于預(yù)測噪聲的核心模型:UNet,到底長成什么樣。我在學(xué)習(xí)DDPM的過程中,在網(wǎng)上幾乎找不到關(guān)于DDPM UNet的詳細(xì)模型解說,或者一張清晰的架構(gòu)圖,這給我在源碼閱讀過程中增加了難度。所以在讀完源碼并進(jìn)行實(shí)操訓(xùn)練后,我干脆自己畫一張出來,也借此幫助自己更好理解DDPM。
四、DDPM中的Unet架構(gòu)
UNet模型最早提出時(shí),是用于解決醫(yī)療影響診斷問題的。總體上說,它分成兩個(gè)部分:
Encoder
Decoder
在Encoder部分中,UNet模型會逐步壓縮圖片的大小;在Decoder部分中,則會逐步還原圖片的大小。同時(shí)在Encoder和Deocder間,還會使用“殘差連接”,確保Decoder部分在推理和還原圖片信息時(shí),不會丟失掉之前步驟的信息。整體過程示意圖如下,因?yàn)閴嚎s再放大的過程形似"U"字,因此被稱為UNet:
那么DDPM中的UNet,到底長什么樣子呢?我們假設(shè)輸入為一張32323大小的圖片,來看一下DDPM UNet運(yùn)作的完整流程:
如圖,左半邊為UNet的Encoder部分,右半邊為UNet的Deocder部分,最下面為MiddleBlock。我們以從上往下數(shù)第二行來分析UNet的運(yùn)作流程。
在Encoder部分的第二行,輸入是一個(gè)16*16*64的圖片,它是由上一行最右側(cè)32*32*64的圖片壓縮而來(DownSample)。對于這張16*16*64大小的圖片,在引入time_embedding后,讓它們一起過一層DownBlock,得到大小為16*16*128的圖片。再引入time_embedding,再過一次DownBlock,得到大小同樣為16*16*128的圖片。對該圖片做DowSample,就可以得到第三層的輸入,也就是大小為8*8*128的圖片。由此不難知道,同層間只做channel上的變化,不同層間做圖片的壓縮處理。至于每一層channel怎么變,層間size如何調(diào)整,就取決于實(shí)際訓(xùn)練中對模型的設(shè)定了。Decoder層也是同理。其余的信息可以參見圖片,這里不再贅述。
我們再詳細(xì)來看右下角箭頭所表示的那些模型部分,具體架構(gòu)長什么樣:
4.1 DownBlock和UpBlock
如果你曾在學(xué)習(xí)DDPM的過程中,困惑time_embedding要如何與圖片相加,Attention要在哪里做,那么這張圖可以幫助你解答這些困惑。TimeEmbedding層采用和Transformer一致的三角函數(shù)位置編碼,將常數(shù)轉(zhuǎn)變?yōu)橄蛄俊ttention層則是沿著channel維度將圖片拆分為token,做完attention后再重新組裝成圖片(注意Attention層不是必須的,是可選的,可以根據(jù)需要選擇要不要上attention)。
你可能想問:一定要沿著channel方向拆分圖片為token嗎?我可以選擇VIT那樣以patch維度拆分token,節(jié)省計(jì)算量嗎?當(dāng)然沒問題,你可以做各種實(shí)驗(yàn),這只是提供DDPM對圖片做attention的一種方法。
4.2 DownSample和UpSample
這個(gè)模塊很簡單,就是壓縮(Conv)和放大(ConvT)圖片的過程。對ConvT原理不熟悉的朋友們,可以參考這篇文章(https://blog.csdn.net/sinat_29957455/article/details/85558870)。
4.3 MiddleBlock
和DownBlock與UpBlock的過程相似,不再贅述。
到這一步,我們就把DDPM的模型核心給講完啦。在第三篇源碼解讀中,我們會結(jié)合這些架構(gòu)圖,來一起閱讀DDPM training和sampling代碼。
五、文生圖模型的一般公式
講完了DDPM,讓我們再回到開頭,看看最初我們想訓(xùn)練的那個(gè)“以文生圖”模型吧!
當(dāng)我們擁有了能夠產(chǎn)生逼真圖片的模型后,我們現(xiàn)在能進(jìn)一步用文字信息去引導(dǎo)它產(chǎn)生符合我們意圖的模型了。通常來說,文生圖模型遵循以下公式(圖片來自李宏毅老師課堂PPT):
Text Encoder:一個(gè)能對輸入文字做語義解析的Encoder,一般是一個(gè)預(yù)訓(xùn)練好的模型。在實(shí)際應(yīng)用中,CLIP模型由于在訓(xùn)練過程中采用了圖像和文字的對比學(xué)習(xí),使得學(xué)得的文字特征對圖像更加具有魯棒性,因此它的text encoder常被直接用來做文生圖模型的text encoder(比如DALLE2)
Generation Model: 輸入為文字token和圖片噪聲,輸出為一個(gè)關(guān)于圖片的壓縮產(chǎn)物(latent space)。這里通常指的就是擴(kuò)散模型,采用文字作為引導(dǎo)(guidance)的擴(kuò)散模型原理,我們將在這個(gè)系列的后文中出講解。
Decoder:用圖片的中間產(chǎn)物作為輸入,產(chǎn)出最終的圖片。Decoder的選擇也有很多,同樣也能用一個(gè)擴(kuò)散模型作為Decoder。
5.1 DALLE2
DALLE2就套用了這個(gè)公式。它曾嘗試用Autoregressive和Diffusion分別來做Generation Model,但實(shí)驗(yàn)發(fā)現(xiàn)Diffusion的效果更好。所以最后它的2和3都是一個(gè)Diffusion Model。
Stable Diffusion
大名鼎鼎Stable Diffsuion也能按這個(gè)公式進(jìn)行拆解。
5.3 Imagen
Google的Imagen,小圖生大圖,遵循的也是這個(gè)公式。
按這個(gè)套路一看,是不是文生圖模型,就不難理解了呢?我們在這個(gè)系列后續(xù)文章中,也會對這些效果驚艷的模型,進(jìn)行解讀。
-
源碼
+關(guān)注
關(guān)注
8文章
632瀏覽量
29110 -
模型
+關(guān)注
關(guān)注
1文章
3112瀏覽量
48660
原文標(biāo)題:深入淺出擴(kuò)散模型系列:基石DDPM(模型架構(gòu)篇),最詳細(xì)的DDPM架構(gòu)圖解
文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論