0x0. 背景
我也是偶然在知乎的一個(gè)問(wèn)題下看到這個(gè)問(wèn)題,大概就是說(shuō)在使用apex的LayerNorm/RMSNorm的時(shí)候可以打開(kāi)這個(gè)api的memory_efficient開(kāi)關(guān),這個(gè)開(kāi)關(guān)可以在速度和精度無(wú)損的情況下節(jié)省網(wǎng)絡(luò)訓(xùn)練的顯存占用。感覺(jué)比較有趣,我就研究了一下,因此也就有了這篇文章。
我去實(shí)測(cè)了一下,單機(jī)8卡A100訓(xùn)練LLama7B,純數(shù)據(jù)并行的情況下打開(kāi)memory_efficient開(kāi)關(guān)相比于不打開(kāi)節(jié)省了大約2個(gè)G的顯存,如果模型繼續(xù)scale up,那么省掉的顯存也會(huì)更多。因此,本文就是對(duì)這個(gè)memory_efficient開(kāi)關(guān)的背后實(shí)現(xiàn)做一個(gè)解讀,另外也會(huì)對(duì)apex里面LayerNorm/RMSNorm本身的cuda kernel實(shí)現(xiàn)做一個(gè)細(xì)節(jié)解讀。
apex的LayerNorm/RMSNorm被實(shí)現(xiàn)成一個(gè)fuse kernel,然后上層使用torch.autograd.Function來(lái)封裝,本文的講解主要以LayerNorm為例子
實(shí)際上RMSNorm和LayerNorm的實(shí)現(xiàn)是共享的,只不過(guò)在kernel內(nèi)部會(huì)區(qū)分一下縮放策略是2個(gè)參數(shù)(LayerNorm的gamma和beta)還是一個(gè)參數(shù)。
classFusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod defforward(ctx,input,weight,bias,normalized_shape,eps,memory_efficient=False): globalfused_layer_norm_cuda iffused_layer_norm_cudaisNone: fused_layer_norm_cuda=importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape=normalized_shape ctx.eps=eps ctx.memory_efficient=memory_efficient input_=input.contiguous() weight_=weight.contiguous() bias_=bias.contiguous() output,mean,invvar=fused_layer_norm_cuda.forward_affine( input_,ctx.normalized_shape,weight_,bias_,ctx.eps ) ifctx.memory_efficient: ctx.save_for_backward(output,weight_,bias_,None,invvar) else: ctx.save_for_backward(input_,weight_,bias_,mean,invvar) returnoutput
可以看到在非memory_efficient模式下面,ctx.save_for_backward(output, weight_, bias_, None, invvar)保存了用于backward的tensor,包括輸入,權(quán)重,偏置,均值和方差的逆。但在memory_efficient模式下面ctx.save_for_backward(output, weight_, bias_, None, invvar),則是保存了輸出,權(quán)重偏置以及方差的逆。
這個(gè)地方看下你是否會(huì)掉入誤區(qū)?從表面上看,這里也就只省掉了一個(gè)gamma,因?yàn)檩斎牒洼敵鰐ensor的形狀是一樣的,那么這樣還有什么收益呢?背景是,在pre-ln的transformer架構(gòu)里面LayerNorm/RMSNorm之后緊接著是一個(gè)線性投影,無(wú)論是在注意力機(jī)制還是在多層感知機(jī)(mlp)中都是如此,所以輸出Tensor一定要被保存下來(lái)。而在post-ln架構(gòu)中,輸出還會(huì)直接用于殘差連接。然而,在這兩種情況下,LayerNorm/RMSNorm的輸入都不再被使用,所以這里原本的輸入保存變得相當(dāng)多余,因?yàn)槲覀兛梢员4鏌o(wú)論如何都會(huì)被保存的輸出張量。這樣就可以達(dá)到節(jié)省顯存的目的了。
接下來(lái)就詳細(xì)解讀下實(shí)現(xiàn)。
0x1. Apex的LayerNorm前向cuda實(shí)現(xiàn)
https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda.cpp 這個(gè)文件是基于實(shí)現(xiàn)的LayerNorm cuda kernel使用torch extension模塊導(dǎo)出python接口。
同時(shí)這個(gè)文件還寫了幾個(gè)工具函數(shù),比如compute_n1_n2用來(lái)計(jì)算LayerNorm中非歸一化和歸一化部分的大?。篽ttps://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L7C31-L7C51 ,check_args函數(shù)對(duì)LayerNorm的參數(shù)進(jìn)行檢查:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L32C22-L143 。
此外,這個(gè)cpp預(yù)定義了cuda_layer_norm的函數(shù)接口,并且考慮了gamma/beta是否為空。
接下來(lái)就正式對(duì)LayerNorm的前向cuda實(shí)現(xiàn)進(jìn)行解析。
0x1.1 工具函數(shù)
LayerNorm使用Welford算法統(tǒng)計(jì)均值方差,在 https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu 寫了一系列kernel實(shí)現(xiàn)中需要用到的工具函數(shù),這些函數(shù)是gpu上用到的。下面對(duì)其簡(jiǎn)單解析一下,另外Welford算法可以看這篇博客的介紹:用Welford算法實(shí)現(xiàn)LN的方差更新(感嘆一下,zzk寫這篇文章的時(shí)候還是萌新,經(jīng)過(guò)2年時(shí)間已經(jīng)成長(zhǎng)為國(guó)內(nèi)頂級(jí)的工程師了,開(kāi)掛般學(xué)習(xí)能力) 。工具函數(shù)包含:cuWelfordOnlineSum,cuChanOnlineSum,cuRMSOnlineSum,cuChanRMSOnlineSum這些,我把自己的原始注釋使用gpt4進(jìn)行了潤(rùn)色,這樣會(huì)顯得更加通俗一些。具體解釋如下:
//這段代碼是個(gè)CUDA函數(shù),名叫cuWelfordOnlineSum,擅長(zhǎng)用Welford算法來(lái)邊收數(shù)據(jù)邊算這些數(shù)據(jù)的平均值和變化范圍(就是均值和方差)。 //用Welford算法來(lái)算這個(gè),特別穩(wěn),不會(huì)因?yàn)閿?shù)據(jù)太多而出錯(cuò),而且每加一個(gè)數(shù)據(jù)就能更新一次均值和方差。 // const U curr:這個(gè)是新來(lái)的數(shù)據(jù)點(diǎn)。 // U& mu:這個(gè)是我們到現(xiàn)在為止算出來(lái)的所有數(shù)據(jù)的平均值。 // U& sigma2:這個(gè)是我們到現(xiàn)在為止算出來(lái)的方差,可以告訴你數(shù)據(jù)變化有多大。 // U& count:這個(gè)記錄了我們到現(xiàn)在處理了多少數(shù)據(jù)點(diǎn)。 template__device__ voidcuWelfordOnlineSum( constUcurr, U&mu, U&sigma2, U&count) { count=count+U(1);//每次調(diào)用這個(gè)函數(shù),就把處理的數(shù)據(jù)數(shù)量加一。 Udelta=curr-mu;//看看新數(shù)據(jù)和現(xiàn)有平均值差多少。 Ulmean=mu+delta/count;//用這個(gè)差值和數(shù)據(jù)總量來(lái)算一個(gè)新的平均值。 mu=lmean;//把這個(gè)新算的平均值記下來(lái)。 Udelta2=curr-lmean;//現(xiàn)在再算一下新數(shù)據(jù)和新平均值的差。 sigma2=sigma2+delta*delta2;//利用這個(gè)新舊平均值的差來(lái)更新方差。 } //這段代碼是個(gè)CUDA函數(shù),名叫cuChanOnlineSum。它用于處理一種特殊的情況: //當(dāng)你有兩堆數(shù)據(jù),想要快速算出它們合并后的平均值和方差時(shí),這個(gè)函數(shù)就派上用場(chǎng)了。 // const U muB, sigma2B, countB:這三個(gè)是你新加入的那堆數(shù)據(jù)的平均值、方差和數(shù)據(jù)點(diǎn)數(shù)量。 // U& mu, sigma2, count:這三個(gè)是你之前已經(jīng)有的數(shù)據(jù)的平均值、方差和數(shù)據(jù)點(diǎn)數(shù)量。 //這個(gè)函數(shù)會(huì)更新這些值,讓它們反映出兩堆數(shù)據(jù)合并后的情況。 template __device__ voidcuChanOnlineSum( constUmuB, constUsigma2B, constUcountB, U&mu, U&sigma2, U&count) { Udelta=muB-mu;//先算算新數(shù)據(jù)堆和老數(shù)據(jù)堆的平均值差了多少。 UnA=count;//記下當(dāng)前數(shù)據(jù)堆(我們叫它A堆)的大小。 UnB=countB;//看看新來(lái)的那堆數(shù)據(jù)(B堆)有多少個(gè)點(diǎn)。 count=count+countB;//把兩堆數(shù)據(jù)的數(shù)量加起來(lái)。 UnX=count;//這就是合并后總數(shù)據(jù)量的大小。 if(nX>U(0)){ nA=nA/nX;//算一下A堆數(shù)據(jù)在總數(shù)據(jù)中占的比例。 nB=nB/nX;//同理,算一下B堆的比例。 mu=nA*mu+nB*muB;//利用這些比例和各自的平均值,算出總的平均值。 sigma2=sigma2+sigma2B+delta*delta*nA*nB*nX;//然后用一點(diǎn)復(fù)雜的公式,把方差也算出來(lái),這個(gè)公式考慮了兩堆數(shù)據(jù)的方差和它們平均值的差異。 }else{ //如果合并后的總數(shù)是0,那就說(shuō)明兩堆數(shù)據(jù)其實(shí)都是空的,所以把平均值和方差都設(shè)為0。 mu=U(0); sigma2=U(0); } } //這里定義了一個(gè)名叫cuRMSOnlineSum的CUDA函數(shù),它的主要任務(wù)就是在線實(shí)時(shí)計(jì)算一串?dāng)?shù)據(jù)的平方和。 //你可能會(huì)問(wèn),為什么要算平方和呢?這是因?yàn)槲覀兛梢杂盟鼇?lái)算出均方根(RMS, Root Mean Square), //均方根是一種描述數(shù)據(jù)波動(dòng)大小的指標(biāo),特別常用于信號(hào)處理領(lǐng)域。 template __device__ voidcuRMSOnlineSum( constUcurr, U&sigma2) { sigma2=sigma2+curr*curr;//每次函數(shù)被調(diào)用,就把當(dāng)前值的平方加到累計(jì)平方和中。 } //又定義了一個(gè)名叫cuChanRMSOnlineSum的CUDA函數(shù),這個(gè)家伙的工作就是幫你算兩組數(shù)據(jù)的平方和總和。 //當(dāng)你有兩組數(shù)據(jù),想要快速合并它們的均方根(RMS)時(shí),這個(gè)函數(shù)就能派上用場(chǎng)。 //它其實(shí)是均方根計(jì)算過(guò)程中的一個(gè)環(huán)節(jié),用于處理兩個(gè)獨(dú)立數(shù)據(jù)集的情況。 template __device__ voidcuChanRMSOnlineSum( constUsigma2B, U&sigma2) { sigma2=sigma2+sigma2B;//這里就簡(jiǎn)單直接了,把第二組數(shù)據(jù)的平方和加到當(dāng)前的累計(jì)值上。 }
這里還有一個(gè)函數(shù)cuWelfordMuSigma2是用來(lái)計(jì)算張量某一維度上的均值(mu)和方差(sigma2)的,它調(diào)用了上面的工具函數(shù),但是這個(gè)函數(shù)我們?cè)趉ernel實(shí)現(xiàn)階段解析,因?yàn)樗枰恍﹌ernel啟動(dòng)的背景。
0x1.2 啟動(dòng)邏輯
先對(duì)kernel啟動(dòng)這部分的代碼進(jìn)行注釋,首先是共享內(nèi)存的結(jié)構(gòu)體定義。
//這段代碼定義了一個(gè)叫做SharedMemory的模板結(jié)構(gòu)體,專門用在CUDA設(shè)備函數(shù)里來(lái)訪問(wèn)所謂的“共享內(nèi)存”。 //在CUDA編程里,共享內(nèi)存是一種特別高效的內(nèi)存類型,非常適合用來(lái)在CUDA的一個(gè)塊(block)內(nèi)的不同線程間共享數(shù)據(jù)。 //這里還包括了針對(duì)float和double類型數(shù)據(jù)的SharedMemory結(jié)構(gòu)體的特化版本。 namespace{ //這是通用的SharedMemory結(jié)構(gòu)體模板。注意,我們通過(guò)在函數(shù)體內(nèi)使用一個(gè)未定義的符號(hào)來(lái)阻止這個(gè)結(jié)構(gòu)體被實(shí)例化, //這樣如果嘗試用未特化的類型來(lái)編譯這個(gè)結(jié)構(gòu)體,編譯器就會(huì)報(bào)錯(cuò)。 //template//structSharedMemory //{ ////確保我們不會(huì)編譯任何未特化的類型 //__device__T*getPointer() //{ //extern__device__voiderror(void); //error(); //returnNULL; //} //}; template structSharedMemory; //這是SharedMemory結(jié)構(gòu)體針對(duì)float類型的特化版本。 template<> structSharedMemory { //這個(gè)函數(shù)返回一個(gè)指向共享內(nèi)存的float類型指針。 __device__float*getPointer() { //這里聲明了一個(gè)名為s_float的外部共享內(nèi)存數(shù)組,用于存儲(chǔ)float類型的數(shù)據(jù)。 // extern和__shared__關(guān)鍵字表明這個(gè)數(shù)組是在共享內(nèi)存中定義的。 extern__shared__floats_float[]; returns_float; } }; //下面是針對(duì)double類型的特化版本,工作方式和float版本相似。 template<> structSharedMemory { __device__double*getPointer() { extern__shared__doubles_double[]; returns_double; } }; }
然后是Kernel啟動(dòng)的具體邏輯部分:
//這段代碼里,我們定義了一個(gè)CUDA設(shè)備函數(shù)叫做cuApplyLayerNorm_,它的主要任務(wù)是執(zhí)行LayerNorm(層歸一化)。 //層歸一化是深度學(xué)習(xí)中的一個(gè)技巧,用來(lái)讓每一層的輸出更加標(biāo)準(zhǔn)化,有助于模型訓(xùn)練。 //我們定義了三種模板參數(shù):T是輸入數(shù)據(jù)類型,U是中間計(jì)算(比如均值和方差)的類型,V是輸出數(shù)據(jù)類型。 // output_vals, mean, invvar, vals, gamma, beta 這些都是指向不同數(shù)據(jù)的指針。 //在層歸一化中,我們通常把一個(gè)多維數(shù)據(jù)(張量)分為兩部分:一部分用來(lái)做標(biāo)準(zhǔn)化,另一部分保持原樣。 //比如,如果你有一個(gè)[batch_size,channels,height,width]形狀的4D張量, //而你只想對(duì)最后兩個(gè)維度進(jìn)行層歸一化,那么n1是batch_size * channels,n2是height * width。 template__device__ voidcuApplyLayerNorm_( V*__restrict__output_vals, U*__restrict__mean, U*__restrict__invvar, constT*__restrict__vals, constintn1, constintn2, constUepsilon, constV*__restrict__gamma, constV*__restrict__beta, boolrms_only ) { //基本假設(shè): // 1) blockDim.x 是 warp 的大?。ㄟ@是一個(gè)CUDA的技術(shù)細(xì)節(jié))。 // 2)輸入的張量數(shù)據(jù)在內(nèi)存中是連續(xù)的。 // //這段代碼遍歷n1維度,每次處理一個(gè)i1索引。 //假設(shè)每個(gè)CUDA線程塊的x維度等于warp大小,確保數(shù)據(jù)處理是高效的。 //這里一個(gè)線程可能要處理多行,所以我們用gridDim.y來(lái)控制步長(zhǎng)。(因?yàn)間ridDim.x=1) for(autoi1=blockIdx.y;i1shared; U*buf=shared.getPointer();//創(chuàng)建一個(gè) SharedMemory 實(shí)例用于處理類型 U 的數(shù)據(jù)。 Umu,sigma2;//這里mu和sigma2分別代表均值和方差,我們接下來(lái)要計(jì)算它們。 //調(diào)用 cuWelfordMuSigma2 函數(shù)計(jì)算給定索引 i1 處的均值(mu)和方差(sigma2)。 cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only); //定位到當(dāng)前 i1 索引處的輸入和輸出的起始位置。 constT*lvals=vals+i1*n2; V*ovals=output_vals+i1*n2; //計(jì)算逆方差 c_invvar,這是層歸一化中一個(gè)關(guān)鍵的步驟。 Uc_invvar=rsqrt(sigma2+epsilon); //計(jì)算每個(gè) CUDA 塊的線程總數(shù)(numx)和當(dāng)前線程的一維索引(thrx)。 constintnumx=blockDim.x*blockDim.y; constintthrx=threadIdx.x+threadIdx.y*blockDim.x; //如果提供了gamma和beta參數(shù),或者我們只是在做RMS計(jì)算,我們會(huì)用一種特別的方式來(lái)計(jì)算輸出值。 if(gamma!=NULL&&(beta!=NULL||rms_only)){ for(inti=thrx;i(lvals[i]); if(!rms_only){ //標(biāo)準(zhǔn)化當(dāng)前值,然后用gamma和beta進(jìn)行調(diào)整。 ovals[i]=gamma[i]*static_cast (c_invvar*(curr-mu))+beta[i]; }else{ ////如果是RMS模式,我們稍微簡(jiǎn)化計(jì)算過(guò)程。 ovals[i]=gamma[i]*static_cast (c_invvar*curr); } } } //否則,如果沒(méi)有提供gamma和beta,我們就直接用計(jì)算出的均值和逆方差來(lái)進(jìn)行標(biāo)準(zhǔn)化。 else{ for(inti=thrx;i(lvals[i]); if(!rms_only){ //直接進(jìn)行標(biāo)準(zhǔn)化計(jì)算。 ovals[i]=static_cast (c_invvar*(curr-mu)); }else{ //// RMS模式下的簡(jiǎn)化計(jì)算。 ovals[i]=static_cast (c_invvar*curr); } } } //在每個(gè) CUDA 塊中,僅由一個(gè)線程(線程(0,0))更新均值和逆方差。 if(threadIdx.x==0&&threadIdx.y==0){ if(!rms_only){ mean[i1]=mu; } invvar[i1]=c_invvar; } //用于同步塊內(nèi)的所有線程。 __syncthreads(); } } //對(duì)上個(gè)函數(shù)的參數(shù)透?jìng)鳎贿^(guò)rms_only設(shè)為False template __global__ voidcuApplyLayerNorm( V*__restrict__output_vals, U*__restrict__mean, U*__restrict__invvar, constT*__restrict__vals, constintn1, constintn2, constUepsilon, constV*__restrict__gamma, constV*__restrict__beta ) { cuApplyLayerNorm_ (output_vals,mean,invvar,vals,n1,n2,epsilon,gamma,beta,false); } //kernel啟動(dòng)代碼,設(shè)置線程塊和線程數(shù) template voidHostApplyLayerNorm( V*output, U*mean, U*invvar, constT*input, intn1, intn2, doubleepsilon, constV*gamma, constV*beta ) { // threads和blocks定義了CUDA內(nèi)核的線程和塊的維度。這里,每個(gè)線程塊有32×4的線程,而塊的數(shù)量由n1和GPU設(shè)備的最大網(wǎng)格大小限制決定。 autostream=at::getCurrentCUDAStream().stream(); constdim3threads(32,4,1); constuint64_tmaxGridY=at::getCurrentDeviceProperties()->maxGridSize[1]; constdim3blocks(1,std::min((uint64_t)n1,maxGridY),1); //這段代碼計(jì)算內(nèi)核函數(shù)需要多少共享內(nèi)存。如果threads.y大于1,它會(huì)根據(jù)U類型的大小分配足夠的內(nèi)存。 intnshared= threads.y>1? threads.y*sizeof(U)+(threads.y/2)*sizeof(U): 0; //最后,函數(shù)使用cuApplyLayerNorm kernel來(lái)執(zhí)行實(shí)際的LayerNorm操作。 // kernel函數(shù)的調(diào)用使用了之前計(jì)算的線程塊和線程配置,以及共享內(nèi)存大小和CUDA流。 cuApplyLayerNorm<< >>( output,mean,invvar,input,n1,n2,U(epsilon),gamma,beta); }
這段代碼包含了kernel的啟動(dòng)邏輯,包括設(shè)置block的個(gè)數(shù)以及每個(gè)block中的線程排布方式,然后在cuApplyLayerNorm_里面有一個(gè)跨線程網(wǎng)格的大循環(huán)作用在n1維度,每個(gè)線程可能會(huì)處理多行數(shù)據(jù)。而在每一行數(shù)據(jù)的處理上,調(diào)用了cuWelfordMuSigma2 函數(shù)計(jì)算給定索引 i1 處的均值(mu)和方差(sigma2),并隨后在n2維度上來(lái)計(jì)算LayerNorm的輸出,同時(shí)會(huì)在每個(gè)Block的線程(0, 0)更新cuWelfordMuSigma2算出來(lái)的均值和方差(這里的記錄的實(shí)際上是方差的逆)。
0x1.3 kernel實(shí)現(xiàn)
從上面的分析可知,整個(gè)LayerNorm實(shí)現(xiàn)的核心就是cuWelfordMuSigma2函數(shù),下面對(duì)這個(gè)函數(shù)進(jìn)行解析。
//`cuWelfordMuSigma2`是一個(gè)CUDA設(shè)備函數(shù),旨在高效計(jì)算張量某一特定維度上的均值(mu)和方差(sigma2)。 //它基于Welford算法實(shí)現(xiàn),以提高數(shù)值穩(wěn)定性。此外,該函數(shù)支持僅計(jì)算均方根(RMS)作為一種操作模式。 //模板參數(shù):定義了處理張量值(T)和執(zhí)行計(jì)算(U)時(shí)使用的數(shù)據(jù)類型。 // const T*__restrict__ vals:指向張量數(shù)據(jù)的指針。 // const int n1, n2:指定張量的維度,其中n1是參與計(jì)算的維度的大小,n2是被約減的維度的大小。 // const int i1:當(dāng)前正在處理的n1維度上的特定索引。 // U& mu, sigma2:用于存儲(chǔ)計(jì)算得出的均值和方差。 // U* buf:指向用于線程間通訊的共享內(nèi)存緩沖區(qū)的指針。 // bool rms_only:一個(gè)標(biāo)志,用于指示是否僅計(jì)算RMS(為true時(shí))或同時(shí)計(jì)算均值和方差(為false時(shí))。 template __device__ voidcuWelfordMuSigma2( constT*__restrict__vals, constintn1, constintn2, constinti1, U&mu, U&sigma2, U*buf, boolrms_only) { //前提條件: // 1) blockDim.x 等于 warp 的大小。 // 2)輸入的張量在內(nèi)存中連續(xù)存儲(chǔ)。 // 3)有足夠的共享內(nèi)存可用,大小為 2*blockDim.y*sizeof(U)+ blockDim.y*sizeof(int)。 // //在 n2 維度上計(jì)算方差和均值。 //初始化 count, mu, 和 sigma2 為零。 Ucount=U(0); mu=U(0); sigma2=U(0); //確保處理的 i1 索引在張量的有效范圍內(nèi)。 if(i1(lvals[l+k]); //根據(jù) rms_only 標(biāo)志調(diào)用相應(yīng)的函數(shù)來(lái)更新均值和方差或僅更新平方和(用于計(jì)算 RMS)。 if(!rms_only){ cuWelfordOnlineSum(curr,mu,sigma2,count); }else{ cuRMSOnlineSum(curr,sigma2); } } } //這個(gè)循環(huán)處理了之前在步長(zhǎng)為 4*numx 的循環(huán)中未處理的張量元素。每個(gè)線程獨(dú)立處理它們剩余的部分。 for(;l(lvals[l]); if(!rms_only){ cuWelfordOnlineSum(curr,mu,sigma2,count); }else{ cuRMSOnlineSum(curr,sigma2); } } //在同一個(gè)warp內(nèi)進(jìn)行歸約操作。 for(intl=0;l<=?4;??++l)?{ ??????//?是在 CUDA 設(shè)備上進(jìn)行 warp 內(nèi)部數(shù)據(jù)交換的關(guān)鍵部分。 ??????//?這行代碼用于確定在一個(gè) warp(32個(gè)線程)內(nèi),每個(gè)線程應(yīng)該從哪個(gè)“l(fā)ane”(即其他線程)獲取數(shù)據(jù)。 ??????//?(1< (muB,sigma2B,countB,mu,sigma2,count); }else{ cuChanRMSOnlineSum(sigma2B,sigma2); } } //threadIdx.x==0hascorrectvaluesforeachwarp //inter-warpreductions //檢查是否有多個(gè) warp。如果 blockDim.y 大于 1,則表示塊中有多個(gè) warp 需要進(jìn)行reduce操作。 if(blockDim.y>1){ //為方差和均值的reduce操作分配共享內(nèi)存。ubuf 用于存儲(chǔ)方差和均值,ibuf 用于存儲(chǔ)計(jì)數(shù)。 U*ubuf=(U*)buf; U*ibuf=(U*)(ubuf+blockDim.y); //這個(gè)循環(huán)是對(duì) warp 間的reduce操作進(jìn)行分層合并。 for(intoffset=blockDim.y/2;offset>0;offset/=2){ //upperhalfofwarpswritetoshared //確保只有部分線程(warp 的上半部分)將其計(jì)算的結(jié)果寫入共享內(nèi)存。 if(threadIdx.x==0&&threadIdx.y>=offset&&threadIdx.y2*offset)?{ ??????????const?int?wrt_y?=?threadIdx.y?-?offset; ??????????if?(!rms_only)?{ ????????????ubuf[2*wrt_y]?=?mu; ????????????ibuf[wrt_y]?=?count; ??????????} ??????????ubuf[2*wrt_y+1]?=?sigma2; ????????} ????????//?同步以等待共享內(nèi)存存儲(chǔ)完畢 ????????__syncthreads(); ????????//?lower?half?merges ????????//?此部分是對(duì) warp 間數(shù)據(jù)的合并操作。 ????????//?確保只有部分線程(warp 的下半部分)從共享內(nèi)存中讀取數(shù)據(jù)并進(jìn)行合并。 ????????if?(threadIdx.x?==?0?&&?threadIdx.y?(muB,sigma2B,countB,mu,sigma2,count); }else{ cuChanRMSOnlineSum(sigma2B,sigma2); } } __syncthreads(); } //threadIdx.x=0&&threadIdx.y==0onlythreadthathascorrectvalues //最終的結(jié)果由塊內(nèi)的第一個(gè)線程(threadIdx.x ==0&& threadIdx.y ==0)計(jì)算并寫入共享內(nèi)存。 if(threadIdx.x==0&&threadIdx.y==0){ if(!rms_only){ ubuf[0]=mu; } ubuf[1]=sigma2; } __syncthreads(); //如果不是只計(jì)算 RMS,則還需要更新均值 mu。 if(!rms_only){ mu=ubuf[0]; } //計(jì)算最終的方差。 sigma2=ubuf[1]/U(n2); //don'tcareaboutfinalvalueofcount,weknowcount==n2 } //如果塊中只有一個(gè) warp(blockDim.y == 1),則通過(guò) WARP_SHFL 直接在 warp 內(nèi)進(jìn)行數(shù)據(jù)交換和更新。 else{ if(!rms_only){ mu=WARP_SHFL(mu,0); } sigma2=WARP_SHFL(sigma2/U(n2),0); } }
cuWelfordMuSigma2函數(shù)就是在n2維度上使用工具函數(shù)章節(jié)的Weleford方法來(lái)完成均值和方差的計(jì)算,然后這里還借助了共享內(nèi)存來(lái)做warp內(nèi)和warp間的reduce,最終得到全局的均值和方差。
前向的kernel就分析到這里,大家如果想對(duì)LayerNorm的優(yōu)化做進(jìn)一步的了解,推薦看一下OneFlow的SoftMax和LayerNorm優(yōu)化文章。CUDA優(yōu)化之LayerNorm性能優(yōu)化實(shí)踐(https://zhuanlan.zhihu.com/p/443026261) ,這篇文章也是講解了LayerNorm的前向優(yōu)化流程,文章開(kāi)頭有一張性能的圖:
實(shí)際上在大模型時(shí)代,我們的隱藏層維度已經(jīng)越來(lái)越大了,所以我們?cè)趯?shí)際訓(xùn)練的時(shí)候,OneFlow版本的kernel相比于apex的layerNorm在13B之類的模型訓(xùn)練里就拿不到明顯收益了。而在CV中,由于做LayerNorm的維度可能相對(duì)小一些,所以相比于apex的LayerNorm就可以取得明顯加速。
0x2. Apex的LayerNorm反向cuda實(shí)現(xiàn)(memory_efficient相關(guān)計(jì)算)
在apex的LayerNorm反向?qū)崿F(xiàn)時(shí)我們不僅要關(guān)注它的cuda kernel是怎么寫的,還要關(guān)注memory_efficient打開(kāi)時(shí)是如何根據(jù)輸出來(lái)計(jì)算梯度的。我們知道LayerNorm需要對(duì)輸入,gamma,beta都計(jì)算梯度,介于篇幅原因,這里對(duì)實(shí)現(xiàn)得最復(fù)雜的gamma/beta的反向過(guò)程進(jìn)行走讀。
0x2.1 啟動(dòng)邏輯
這里從kernel的啟動(dòng)邏輯開(kāi)始梳理:
//這是一個(gè)模板函數(shù),支持不同的數(shù)據(jù)類型:T(輸入數(shù)據(jù)類型)、 // U(通常用于中間計(jì)算的數(shù)據(jù)類型,默認(rèn)為float)、V(輸出數(shù)據(jù)類型,默認(rèn)與T相同)。 //參數(shù)包括輸出梯度(dout)、均值(mean)、方差倒數(shù)(invvar)、輸入或輸出的PyTorch張量(input_or_output)、 //兩個(gè)維度參數(shù)(n1、n2)、gamma和beta參數(shù)、用于數(shù)值穩(wěn)定的epsilon、輸入梯度(grad_input)、 // gamma梯度(grad_gamma)和beta梯度(grad_beta)、以及一個(gè)指示是否優(yōu)化內(nèi)存使用的布爾值(memory_efficient)。 templatevoidHostLayerNormGradient( constV*dout, constU*mean, constU*invvar, at::Tensor*input_or_output, intn1, intn2, constV*gamma, constV*beta, doubleepsilon, T*grad_input, V*grad_gamma, V*grad_beta, boolmemory_efficient ) { //獲取當(dāng)前CUDA流以用于后續(xù)的CUDA內(nèi)核調(diào)用。 autostream=at::getCurrentCUDAStream().stream(); //如果gamma和beta不為NULL,函數(shù)會(huì)計(jì)算它們的梯度。 //這涉及兩個(gè)CUDA內(nèi)核的調(diào)用:cuComputePartGradGammaBeta和cuComputeGradGammaBeta。 if(gamma!=NULL&&beta!=NULL){ //computegrad_gamma(j)andgrad_beta(j) // part_size是分塊計(jì)算梯度時(shí)的部分大小。 constintpart_size=16; // threads2定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×4×1)。 constdim3threads2(32,4,1); // blocks2定義了CUDA網(wǎng)格中的塊數(shù)量,其中,n2維度被分成多個(gè)塊,以確保每個(gè)塊可以處理n2中的一部分。 constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1); //這部分代碼計(jì)算用于CUDA內(nèi)核的共享內(nèi)存大小。nshared2_a和nshared2_b是基于線程和塊維度的兩種不同共享內(nèi)存大小估算。 constintnshared2_a=2*sizeof(U)*threads2.y*threads2.y*(threads2.x+1); constintnshared2_b=threads2.x*threads2.y*sizeof(U); //最終選擇較大的一個(gè)估算值作為實(shí)際的共享內(nèi)存大?。╪shared2)。 constintnshared2=nshared2_a>nshared2_b?nshared2_a:nshared2_b; //note(mkozuki):Icanhardcodepart_grad_gamma'sdtypeasfloatgiventhat //the`cuda_layer_norm_gradient`doesn'tsupportdouble. //根據(jù)輸入或輸出張量的數(shù)據(jù)類型決定局部梯度張量part_grad_gamma和part_grad_beta的數(shù)據(jù)類型。 //如果輸入或輸出是半精度浮點(diǎn)數(shù)(Half)或BFloat16,則使用單精度浮點(diǎn)數(shù)(Float);否則,使用輸入或輸出的相同數(shù)據(jù)類型。 constautopart_grad_dtype= (input_or_output->scalar_type()==at::Half||input_or_output->scalar_type()==at::BFloat16)? at::Float: input_or_output->scalar_type(); //創(chuàng)建兩個(gè)新的PyTorch張量part_grad_gamma和part_grad_beta,用于存儲(chǔ)gamma和beta的局部梯度計(jì)算結(jié)果。 at::Tensorpart_grad_gamma=at::empty({part_size,n2},input_or_output->options().dtype(part_grad_dtype)); at::Tensorpart_grad_beta=at::empty_like(part_grad_gamma); //使用BOOL_SWITCH宏處理memory_efficient參數(shù),以決定是否使用內(nèi)存高效版本的CUDA內(nèi)核。 //調(diào)用cuComputePartGradGammaBeta內(nèi)核計(jì)算gamma和beta的梯度。 //這個(gè)內(nèi)核函數(shù)接收必要的輸入?yún)?shù),并將梯度結(jié)果寫入part_grad_gamma和part_grad_beta張量。 BOOL_SWITCH(memory_efficient,MemoryEfficient,[&]{ autokernel=&cuComputePartGradGammaBeta ; kernel<< >>( dout, input_or_output->DATA_PTR (), n1,n2, mean, invvar, U(epsilon), gamma, beta, part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), epsilon, false); }); //定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×8×1)。 constdim3threads3(32,8,1); //定義了CUDA網(wǎng)格中的塊數(shù)量。在這里,n2維度被分成多個(gè)塊,每個(gè)塊的大小由threads2.x(之前定義的線程數(shù)量)確定。 constdim3blocks3((n2+threads2.x-1)/threads2.x,1,1); //這行代碼計(jì)算了cuComputeGradGammaBeta內(nèi)核所需的共享內(nèi)存大小。它基于threads3線程塊的維度和數(shù)據(jù)類型U的大小。 constintnshared3=threads3.x*threads3.y*sizeof(U); //kernel接收局部梯度張量(part_grad_gamma和part_grad_beta)、塊大?。╬art_size)、 //維度參數(shù)(n1和n2)和指向梯度輸出的指針(grad_gamma和grad_beta)。 cuComputeGradGammaBeta<< >>( part_grad_gamma.DATA_PTR(), part_grad_beta.DATA_PTR(), part_size, n1,n2, grad_gamma, grad_beta, false); } ... }
這里省略了計(jì)算輸入梯度的啟動(dòng)代碼,只看計(jì)算gamma和beta梯度的代碼??梢园l(fā)現(xiàn),這里對(duì)gamma和beta的梯度進(jìn)行計(jì)算時(shí)使用了分塊計(jì)算的方式,首先會(huì)調(diào)用cuComputePartGradGammaBeta這個(gè)kernel計(jì)算出一個(gè)部分gamma和部分beta,也就是part_grad_gamma和part_grad_beta,需要注意這個(gè)kernel開(kāi)啟的線程塊為:const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1),其中part_size=16,此外每個(gè)線程塊中的線程排布為:const dim3 threads2(32,4,1),即每個(gè)線程塊有128個(gè)線程。我們可以簡(jiǎn)單算一下block2的大小,threads2.x=32,那么blocks2=(n2/32,16,1),也就是一共會(huì)有n2/2個(gè)線程塊。
使用cuComputePartGradGammaBeta計(jì)算完局部gamma和beta的grad之后,會(huì)調(diào)用cuComputeGradGammaBeta這個(gè)kernel來(lái)匯總?cè)值膅amma和beta的梯度。這里開(kāi)啟的線程塊為:const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1),而每個(gè)線程塊里面有256個(gè)線程,排布為const dim3 threads3(32,8,1)。
現(xiàn)在了解了線程塊的組織方式就需要去kernel實(shí)現(xiàn)里面對(duì)應(yīng)看一下具體是怎么計(jì)算的。
0x2.2 kernel計(jì)算邏輯
首先來(lái)看分段計(jì)算gamma和beta梯度的kernel實(shí)現(xiàn),注釋如下:
// part_size是分塊計(jì)算梯度時(shí)的部分大小。 //constintpart_size=16; // threads2定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×4×1)。 //constdim3threads2(32,4,1); // blocks2定義了CUDA網(wǎng)格中的塊數(shù)量,其中,n2維度被分成多個(gè)塊,以確保每個(gè)塊可以處理n2中的一部分。 //constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1); //-> //blockDim.x=32,blockDim.y=4,gridDim.y=16 //假設(shè)n1=4,n2=256,并且當(dāng)前是第一個(gè)線程塊 template__global__ voidcuComputePartGradGammaBeta( constV*__restrict__dout, constT*__restrict__input_or_output, constintn1, constintn2, constU*__restrict__mean, constU*__restrict__invvar, Uepsilon, constV*__restrict__gamma, constV*__restrict__beta, U*part_grad_gamma, U*part_grad_beta, constdoubleeps, boolrms_only) { // numsegs_n1計(jì)算n1維度(4)被分成多少段。使用blockDim.y*blockDim.y(16)作為分段大小。 //帶入值:numsegs_n1 =(4 + 16 - 1)/ 16 = 1。 constintnumsegs_n1=(n1+blockDim.y*blockDim.y-1)/(blockDim.y*blockDim.y); // segs_per_block計(jì)算每個(gè)線程塊要處理的段數(shù)。 //帶入值:segs_per_block =(1 + 16 - 1)/ 16 = 1。 constintsegs_per_block=(numsegs_n1+gridDim.y-1)/gridDim.y; //這些行計(jì)算當(dāng)前線程塊開(kāi)始和結(jié)束處理n1維度的索引 //i1_beg和i1_beg_plus_one相差segs_per_block*blockDim.y*blockDim.y=1*4*4=16 //帶入blockIdx.y =0:i1_beg =0* 1 * 4 * 4 =0, i1_beg_plus_one = 1 * 1 * 4 * 4 = 16,i1_end = min(16, 4)= 4 constinti1_beg=blockIdx.y*segs_per_block*blockDim.y*blockDim.y; constinti1_beg_plus_one=(blockIdx.y+1)*segs_per_block*blockDim.y*blockDim.y; constinti1_end=i1_beg_plus_oneshared; U*buf=shared.getPointer();//bufhasatleastblockDim.x*blockDim.y*blockDim.y+(blockDim.y-1)*(blockDim.x/blockDim.y)elements U*warp_buf1=(U*)buf;//大小是31*4*4=496 U*warp_buf2=warp_buf1+blockDim.y*blockDim.y*row_stride;//大小是3*(32/4)=24 //computepartialsumsfromstridedinputs //dothistoincreasenumberofloadsinflight cuLoadWriteStridedInputs (i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only); // for循環(huán)處理每個(gè)數(shù)據(jù)塊(由i1_beg和i1_end確定)。 //它在數(shù)據(jù)塊之間以步幅blockDim.y*blockDim.y迭代,允許不同的線程塊處理不同的數(shù)據(jù)區(qū)域。 for(inti1_block=i1_beg+blockDim.y*blockDim.y;i1_block(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only); } //確保在所有線程完成其加載和處理操作之前,沒(méi)有線程會(huì)繼續(xù)執(zhí)行后續(xù)的操作。 __syncthreads(); //inter-warpreductions //sumwithineachwarp //這部分代碼執(zhí)行內(nèi)部歸約,計(jì)算每個(gè)warp內(nèi)部的部分和。 // acc1和acc2分別用于累積來(lái)自warp_buf1和warp_buf2的值。這些緩沖區(qū)包含之前步驟計(jì)算的中間結(jié)果。 Uacc1=U(0); Uacc2=U(0); //內(nèi)部循環(huán)對(duì)于blockDim.y內(nèi)的每一行進(jìn)行累加,if (!rms_only)條件檢查是否需要執(zhí)行特定的分支邏輯。 //需要特別注意,這個(gè)累加實(shí)際上是在列方向上也就是n2維度,在n2維度上一個(gè)線程負(fù)責(zé)計(jì)算blockDim.y列 for(intk=0;k1;offset/=2){ //在每次迭代中,只有threadIdx.y小于當(dāng)前offset的線程會(huì)參與計(jì)算,這樣可以避免重復(fù)的工作。 if(threadIdx.y
在理解這段代碼之前,有一個(gè)大前提,那就是這里的訪問(wèn)方式是n1是和blockDim.y綁定的,而n2是和blockDim.x綁定的,而且以二維矩陣的角度來(lái)看,n1是在列方向,而n2是在行的方向。然后const int row_stride = blockDim.x+1這一行是對(duì)共享內(nèi)存進(jìn)行padding避免Bank Conflict的,而在計(jì)算時(shí)對(duì)共享內(nèi)存的訪問(wèn)就是按照列來(lái)訪問(wèn),徹底避免bank conflict。
這也是為什么cuLoadWriteStridedInputs和cuLoadAddStridedInputs函數(shù)名中有一個(gè)Strided,這也暗示了它們的訪問(wèn)模式是跨stride的。剩下的部分其實(shí)和前向就比較類似了,做warp內(nèi)和warp間的reduce。
另外一個(gè)值得注意的點(diǎn)是在cuLoadWriteStridedInputs和cuLoadAddStridedInputs計(jì)算時(shí),會(huì)根據(jù)memory_efficient開(kāi)關(guān)選擇不同的計(jì)算公式,分別從輸入和輸出來(lái)計(jì)算出梯度,達(dá)到kernel內(nèi)部重計(jì)算的目的。
//這段代碼定義了一個(gè)名為cuLoadWriteStridedInputs的CUDA設(shè)備函數(shù)模板,用于在計(jì)算LayerNorm的梯度時(shí), //從輸入張量中加載數(shù)據(jù)并進(jìn)行必要的計(jì)算,將結(jié)果存儲(chǔ)在 warp 緩沖區(qū)中。這個(gè)函數(shù)支持內(nèi)存高效模式(MemoryEfficient)。 //模板參數(shù) T, U, V 代表不同的數(shù)據(jù)類型。 // bool MemoryEfficient 用于選擇是否采用內(nèi)存高效的方式處理數(shù)據(jù)。 //__device__表明這是一個(gè) CUDA 設(shè)備函數(shù)。 //函數(shù)參數(shù)包括各種用于LayerNorm梯度計(jì)算的數(shù)據(jù), //如輸入/輸出張量、梯度張量 dout、均值 mean、逆方差 invvar、縮放參數(shù) gamma、偏移參數(shù) beta 等。 template__device__ voidcuLoadWriteStridedInputs( constinti1_block, constintthr_load_row_off, constintthr_load_col_off, constinti2_off, constintrow_stride, U*warp_buf1, U*warp_buf2, constT*input_or_output, constV*dout, constinti1_end, constintn2, constU*__restrict__mean, constU*__restrict__invvar, constV*__restrict__gamma, constV*__restrict__beta, constdoubleeps, boolrms_only ) { //計(jì)算 i1,表示當(dāng)前處理的行索引。 inti1=i1_block+thr_load_row_off; if(i1(input_or_output[load_idx]); Ucurr_dout=static_cast(dout[load_idx]); //根據(jù) rms_only 和 MemoryEfficient 的值,使用不同的公式計(jì)算梯度,并將結(jié)果存儲(chǔ)在 warp 緩沖區(qū)中。 if(!rms_only){ warp_buf1[write_idx]=curr_dout; if(MemoryEfficient){ Ucurr_beta=static_cast(beta[i2]); warp_buf2[write_idx]=curr_dout*(c_h-curr_beta)/static_cast(clamp_by_magnitude(gamma[i2],eps)); }else{ warp_buf2[write_idx]=curr_dout*(c_h-mean[i1])*invvar[i1]; } }else{ if(MemoryEfficient){ warp_buf2[write_idx]=curr_dout*(c_h)/static_cast(clamp_by_magnitude(gamma[i2],eps)); }else{ warp_buf2[write_idx]=curr_dout*(c_h)*invvar[i1]; } } }else{ //對(duì)于超出 n2 范圍的索引,將相應(yīng)的 warp 緩沖區(qū)位置設(shè)置為0。 if(!rms_only){ warp_buf1[write_idx]=U(0); } warp_buf2[write_idx]=U(0); } } }else{ //對(duì)于超出 n1 范圍的索引,也將相應(yīng)的 warp 緩沖區(qū)位置設(shè)置為0。 for(intk=0;k
執(zhí)行完cuComputePartGradGammaBeta這個(gè)kernel之后,它的輸出part_grad_gamma和part_grad_beta分別以行為n2列為n1的內(nèi)存視角保存了LayerNorm的局部均值和方差的梯度。
接下來(lái)會(huì)使用cuComputeGradGammaBeta這個(gè)kernel來(lái)計(jì)算全局的均值和方差的梯度,由于局部計(jì)算的時(shí)候分塊大小是16,而每個(gè)線程負(fù)責(zé)了4行的計(jì)算,那么這里還需要累積16/4=4次,以得到當(dāng)前行的所有局部梯度的和。
//blockDim.x=n2/32,blockDim.y=1 //threadDim.x=32,threadDim.y=8 template__global__ voidcuComputeGradGammaBeta( constU*part_grad_gamma, constU*part_grad_beta, constintpart_size, constintn1, constintn2, V*grad_gamma, V*grad_beta, boolrms_only) { //sumpartialgradientsforgammaandbeta SharedMemoryshared; U*buf=shared.getPointer(); //計(jì)算每個(gè)線程的全局索引i2,用于確定它在n2維度上的位置。 inti2=blockIdx.x*blockDim.x+threadIdx.x; //如果線程索引i2小于n2的大小,該線程會(huì)參與計(jì)算。 if(i2=1;offset/=2){ //tophalfwritetosharedmemory //在這個(gè)歸約階段,線程首先將其累加結(jié)果寫入共享內(nèi)存,然后從共享內(nèi)存讀取并繼續(xù)累加。 if(threadIdx.y>=offset&&threadIdx.y2*offset)?{ ??????????const?int?write_idx?=?(threadIdx.y?-?offset)?*?blockDim.x?+?threadIdx.x; ??????????buf[write_idx]?=?sum_gamma; ??????????if?(!rms_only)?{ ????????????buf[write_idx+nbsize3]?=?sum_beta; ??????????} ????????} ????????//?__syncthreads()在每次迭代結(jié)束時(shí)同步所有線程,確保共享內(nèi)存的一致性。 ????????__syncthreads(); ????????//?bottom?half?sums ????????if?(threadIdx.y?
注意,for (int offset = blockDim.y/2; offset >= 1; offset /= 2) 這個(gè)循環(huán)包起來(lái)的代碼在這里不會(huì)工作,因?yàn)檫@個(gè)kernel的啟動(dòng)設(shè)置中 blockDim.y=1。另外,我們知道輸入的數(shù)據(jù)已經(jīng)是寫到全局內(nèi)存里面的了,已經(jīng)是同步之后的了,然后每個(gè)線程累積4次這個(gè)過(guò)程也是從global memory里面先讀再計(jì)算最后寫回全局內(nèi)存,所以確實(shí)不需要再reduce了。
關(guān)于memory_efficient開(kāi)關(guān)打開(kāi)時(shí)的梯度計(jì)算公式,按照 https://github.com/NVIDIA/apex/pull/1715 這個(gè)pr 來(lái)看應(yīng)該就是把原始的輸入用重計(jì)算的輸入替換之后再代入到之前的梯度計(jì)算公式中算出來(lái)的。
?
https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L579 這里就對(duì)應(yīng)了對(duì)gamma的梯度,https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L582C5-L582C5 這里則對(duì)應(yīng)了對(duì)beta的梯度。這里的就等于,公式和代碼實(shí)現(xiàn)都能完整對(duì)應(yīng)上。
0x3. 總結(jié)
這篇文章記錄了筆者在研究大模型訓(xùn)練中偶然見(jiàn)到的一個(gè)Trick的代碼解密過(guò)程,希望對(duì)學(xué)習(xí)cuda的小伙伴有所幫助,謝謝大家。
審核編輯:劉清
-
NVIDIA
+關(guān)注
關(guān)注
14文章
4856瀏覽量
102711 -
RMS
+關(guān)注
關(guān)注
2文章
137瀏覽量
35720 -
python
+關(guān)注
關(guān)注
55文章
4768瀏覽量
84376 -
CUDA
+關(guān)注
關(guān)注
0文章
121瀏覽量
13585 -
GPU芯片
+關(guān)注
關(guān)注
1文章
303瀏覽量
5770
原文標(biāo)題:【BBuf的CUDA筆記】十二,LayerNorm/RMSNorm的重計(jì)算實(shí)現(xiàn)
文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論