0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

LayerNorm/RMSNorm的重計(jì)算實(shí)現(xiàn)

jf_pmFSk4VX ? 來(lái)源:GiantPandaCV ? 2024-01-16 09:55 ? 次閱讀

0x0. 背景

我也是偶然在知乎的一個(gè)問(wèn)題下看到這個(gè)問(wèn)題,大概就是說(shuō)在使用apex的LayerNorm/RMSNorm的時(shí)候可以打開(kāi)這個(gè)api的memory_efficient開(kāi)關(guān),這個(gè)開(kāi)關(guān)可以在速度和精度無(wú)損的情況下節(jié)省網(wǎng)絡(luò)訓(xùn)練的顯存占用。感覺(jué)比較有趣,我就研究了一下,因此也就有了這篇文章。

我去實(shí)測(cè)了一下,單機(jī)8卡A100訓(xùn)練LLama7B,純數(shù)據(jù)并行的情況下打開(kāi)memory_efficient開(kāi)關(guān)相比于不打開(kāi)節(jié)省了大約2個(gè)G的顯存,如果模型繼續(xù)scale up,那么省掉的顯存也會(huì)更多。因此,本文就是對(duì)這個(gè)memory_efficient開(kāi)關(guān)的背后實(shí)現(xiàn)做一個(gè)解讀,另外也會(huì)對(duì)apex里面LayerNorm/RMSNorm本身的cuda kernel實(shí)現(xiàn)做一個(gè)細(xì)節(jié)解讀。

apex的LayerNorm/RMSNorm被實(shí)現(xiàn)成一個(gè)fuse kernel,然后上層使用torch.autograd.Function來(lái)封裝,本文的講解主要以LayerNorm為例子

實(shí)際上RMSNorm和LayerNorm的實(shí)現(xiàn)是共享的,只不過(guò)在kernel內(nèi)部會(huì)區(qū)分一下縮放策略是2個(gè)參數(shù)(LayerNorm的gamma和beta)還是一個(gè)參數(shù)。

classFusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
defforward(ctx,input,weight,bias,normalized_shape,eps,memory_efficient=False):
globalfused_layer_norm_cuda
iffused_layer_norm_cudaisNone:
fused_layer_norm_cuda=importlib.import_module("fused_layer_norm_cuda")
ctx.normalized_shape=normalized_shape
ctx.eps=eps
ctx.memory_efficient=memory_efficient
input_=input.contiguous()
weight_=weight.contiguous()
bias_=bias.contiguous()
output,mean,invvar=fused_layer_norm_cuda.forward_affine(
input_,ctx.normalized_shape,weight_,bias_,ctx.eps
)
ifctx.memory_efficient:
ctx.save_for_backward(output,weight_,bias_,None,invvar)
else:
ctx.save_for_backward(input_,weight_,bias_,mean,invvar)
returnoutput

可以看到在非memory_efficient模式下面,ctx.save_for_backward(output, weight_, bias_, None, invvar)保存了用于backward的tensor,包括輸入,權(quán)重,偏置,均值和方差的逆。但在memory_efficient模式下面ctx.save_for_backward(output, weight_, bias_, None, invvar),則是保存了輸出,權(quán)重偏置以及方差的逆。

這個(gè)地方看下你是否會(huì)掉入誤區(qū)?從表面上看,這里也就只省掉了一個(gè)gamma,因?yàn)檩斎牒洼敵鰐ensor的形狀是一樣的,那么這樣還有什么收益呢?背景是,在pre-ln的transformer架構(gòu)里面LayerNorm/RMSNorm之后緊接著是一個(gè)線性投影,無(wú)論是在注意力機(jī)制還是在多層感知機(jī)(mlp)中都是如此,所以輸出Tensor一定要被保存下來(lái)。而在post-ln架構(gòu)中,輸出還會(huì)直接用于殘差連接。然而,在這兩種情況下,LayerNorm/RMSNorm的輸入都不再被使用,所以這里原本的輸入保存變得相當(dāng)多余,因?yàn)槲覀兛梢员4鏌o(wú)論如何都會(huì)被保存的輸出張量。這樣就可以達(dá)到節(jié)省顯存的目的了。

接下來(lái)就詳細(xì)解讀下實(shí)現(xiàn)。

0x1. Apex的LayerNorm前向cuda實(shí)現(xiàn)

https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda.cpp 這個(gè)文件是基于實(shí)現(xiàn)的LayerNorm cuda kernel使用torch extension模塊導(dǎo)出python接口。

同時(shí)這個(gè)文件還寫了幾個(gè)工具函數(shù),比如compute_n1_n2用來(lái)計(jì)算LayerNorm中非歸一化和歸一化部分的大?。篽ttps://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L7C31-L7C51 ,check_args函數(shù)對(duì)LayerNorm的參數(shù)進(jìn)行檢查:https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda.cpp#L32C22-L143 。

此外,這個(gè)cpp預(yù)定義了cuda_layer_norm的函數(shù)接口,并且考慮了gamma/beta是否為空。

接下來(lái)就正式對(duì)LayerNorm的前向cuda實(shí)現(xiàn)進(jìn)行解析。

0x1.1 工具函數(shù)

LayerNorm使用Welford算法統(tǒng)計(jì)均值方差,在 https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu 寫了一系列kernel實(shí)現(xiàn)中需要用到的工具函數(shù),這些函數(shù)是gpu上用到的。下面對(duì)其簡(jiǎn)單解析一下,另外Welford算法可以看這篇博客的介紹:用Welford算法實(shí)現(xiàn)LN的方差更新(感嘆一下,zzk寫這篇文章的時(shí)候還是萌新,經(jīng)過(guò)2年時(shí)間已經(jīng)成長(zhǎng)為國(guó)內(nèi)頂級(jí)的工程師了,開(kāi)掛般學(xué)習(xí)能力) 。工具函數(shù)包含:cuWelfordOnlineSum,cuChanOnlineSum,cuRMSOnlineSum,cuChanRMSOnlineSum這些,我把自己的原始注釋使用gpt4進(jìn)行了潤(rùn)色,這樣會(huì)顯得更加通俗一些。具體解釋如下:

//這段代碼是個(gè)CUDA函數(shù),名叫cuWelfordOnlineSum,擅長(zhǎng)用Welford算法來(lái)邊收數(shù)據(jù)邊算這些數(shù)據(jù)的平均值和變化范圍(就是均值和方差)。
//用Welford算法來(lái)算這個(gè),特別穩(wěn),不會(huì)因?yàn)閿?shù)據(jù)太多而出錯(cuò),而且每加一個(gè)數(shù)據(jù)就能更新一次均值和方差。
// const U curr:這個(gè)是新來(lái)的數(shù)據(jù)點(diǎn)。
// U& mu:這個(gè)是我們到現(xiàn)在為止算出來(lái)的所有數(shù)據(jù)的平均值。
// U& sigma2:這個(gè)是我們到現(xiàn)在為止算出來(lái)的方差,可以告訴你數(shù)據(jù)變化有多大。
// U& count:這個(gè)記錄了我們到現(xiàn)在處理了多少數(shù)據(jù)點(diǎn)。

template__device__
voidcuWelfordOnlineSum(
constUcurr,
U&mu,
U&sigma2,
U&count)
{
count=count+U(1);//每次調(diào)用這個(gè)函數(shù),就把處理的數(shù)據(jù)數(shù)量加一。
Udelta=curr-mu;//看看新數(shù)據(jù)和現(xiàn)有平均值差多少。
Ulmean=mu+delta/count;//用這個(gè)差值和數(shù)據(jù)總量來(lái)算一個(gè)新的平均值。
mu=lmean;//把這個(gè)新算的平均值記下來(lái)。
Udelta2=curr-lmean;//現(xiàn)在再算一下新數(shù)據(jù)和新平均值的差。
sigma2=sigma2+delta*delta2;//利用這個(gè)新舊平均值的差來(lái)更新方差。
}

//這段代碼是個(gè)CUDA函數(shù),名叫cuChanOnlineSum。它用于處理一種特殊的情況:
//當(dāng)你有兩堆數(shù)據(jù),想要快速算出它們合并后的平均值和方差時(shí),這個(gè)函數(shù)就派上用場(chǎng)了。
// const U muB, sigma2B, countB:這三個(gè)是你新加入的那堆數(shù)據(jù)的平均值、方差和數(shù)據(jù)點(diǎn)數(shù)量。
// U& mu, sigma2, count:這三個(gè)是你之前已經(jīng)有的數(shù)據(jù)的平均值、方差和數(shù)據(jù)點(diǎn)數(shù)量。
//這個(gè)函數(shù)會(huì)更新這些值,讓它們反映出兩堆數(shù)據(jù)合并后的情況。

template__device__
voidcuChanOnlineSum(
constUmuB,
constUsigma2B,
constUcountB,
U&mu,
U&sigma2,
U&count)
{
Udelta=muB-mu;//先算算新數(shù)據(jù)堆和老數(shù)據(jù)堆的平均值差了多少。
UnA=count;//記下當(dāng)前數(shù)據(jù)堆(我們叫它A堆)的大小。
UnB=countB;//看看新來(lái)的那堆數(shù)據(jù)(B堆)有多少個(gè)點(diǎn)。
count=count+countB;//把兩堆數(shù)據(jù)的數(shù)量加起來(lái)。
UnX=count;//這就是合并后總數(shù)據(jù)量的大小。
if(nX>U(0)){
nA=nA/nX;//算一下A堆數(shù)據(jù)在總數(shù)據(jù)中占的比例。
nB=nB/nX;//同理,算一下B堆的比例。
mu=nA*mu+nB*muB;//利用這些比例和各自的平均值,算出總的平均值。
sigma2=sigma2+sigma2B+delta*delta*nA*nB*nX;//然后用一點(diǎn)復(fù)雜的公式,把方差也算出來(lái),這個(gè)公式考慮了兩堆數(shù)據(jù)的方差和它們平均值的差異。
}else{
//如果合并后的總數(shù)是0,那就說(shuō)明兩堆數(shù)據(jù)其實(shí)都是空的,所以把平均值和方差都設(shè)為0。
mu=U(0);
sigma2=U(0);
}
}

//這里定義了一個(gè)名叫cuRMSOnlineSum的CUDA函數(shù),它的主要任務(wù)就是在線實(shí)時(shí)計(jì)算一串?dāng)?shù)據(jù)的平方和。
//你可能會(huì)問(wèn),為什么要算平方和呢?這是因?yàn)槲覀兛梢杂盟鼇?lái)算出均方根(RMS, Root Mean Square),
//均方根是一種描述數(shù)據(jù)波動(dòng)大小的指標(biāo),特別常用于信號(hào)處理領(lǐng)域。
template__device__
voidcuRMSOnlineSum(
constUcurr,
U&sigma2)
{
sigma2=sigma2+curr*curr;//每次函數(shù)被調(diào)用,就把當(dāng)前值的平方加到累計(jì)平方和中。
}

//又定義了一個(gè)名叫cuChanRMSOnlineSum的CUDA函數(shù),這個(gè)家伙的工作就是幫你算兩組數(shù)據(jù)的平方和總和。
//當(dāng)你有兩組數(shù)據(jù),想要快速合并它們的均方根(RMS)時(shí),這個(gè)函數(shù)就能派上用場(chǎng)。
//它其實(shí)是均方根計(jì)算過(guò)程中的一個(gè)環(huán)節(jié),用于處理兩個(gè)獨(dú)立數(shù)據(jù)集的情況。
template__device__
voidcuChanRMSOnlineSum(
constUsigma2B,
U&sigma2)
{
sigma2=sigma2+sigma2B;//這里就簡(jiǎn)單直接了,把第二組數(shù)據(jù)的平方和加到當(dāng)前的累計(jì)值上。
}

這里還有一個(gè)函數(shù)cuWelfordMuSigma2是用來(lái)計(jì)算張量某一維度上的均值(mu)和方差(sigma2)的,它調(diào)用了上面的工具函數(shù),但是這個(gè)函數(shù)我們?cè)趉ernel實(shí)現(xiàn)階段解析,因?yàn)樗枰恍﹌ernel啟動(dòng)的背景。

0x1.2 啟動(dòng)邏輯

先對(duì)kernel啟動(dòng)這部分的代碼進(jìn)行注釋,首先是共享內(nèi)存的結(jié)構(gòu)體定義。

//這段代碼定義了一個(gè)叫做SharedMemory的模板結(jié)構(gòu)體,專門用在CUDA設(shè)備函數(shù)里來(lái)訪問(wèn)所謂的“共享內(nèi)存”。
//在CUDA編程里,共享內(nèi)存是一種特別高效的內(nèi)存類型,非常適合用來(lái)在CUDA的一個(gè)塊(block)內(nèi)的不同線程間共享數(shù)據(jù)。
//這里還包括了針對(duì)float和double類型數(shù)據(jù)的SharedMemory結(jié)構(gòu)體的特化版本。

namespace{
//這是通用的SharedMemory結(jié)構(gòu)體模板。注意,我們通過(guò)在函數(shù)體內(nèi)使用一個(gè)未定義的符號(hào)來(lái)阻止這個(gè)結(jié)構(gòu)體被實(shí)例化,
//這樣如果嘗試用未特化的類型來(lái)編譯這個(gè)結(jié)構(gòu)體,編譯器就會(huì)報(bào)錯(cuò)。
//template
//structSharedMemory
//{
////確保我們不會(huì)編譯任何未特化的類型
//__device__T*getPointer()
//{
//extern__device__voiderror(void);
//error();
//returnNULL;
//}
//};

template
structSharedMemory;

//這是SharedMemory結(jié)構(gòu)體針對(duì)float類型的特化版本。
template<>
structSharedMemory
{
//這個(gè)函數(shù)返回一個(gè)指向共享內(nèi)存的float類型指針。
__device__float*getPointer()
{
//這里聲明了一個(gè)名為s_float的外部共享內(nèi)存數(shù)組,用于存儲(chǔ)float類型的數(shù)據(jù)。
// extern和__shared__關(guān)鍵字表明這個(gè)數(shù)組是在共享內(nèi)存中定義的。
extern__shared__floats_float[];
returns_float;
}
};

//下面是針對(duì)double類型的特化版本,工作方式和float版本相似。
template<>
structSharedMemory
{
__device__double*getPointer()
{
extern__shared__doubles_double[];
returns_double;
}
};
}

然后是Kernel啟動(dòng)的具體邏輯部分:

//這段代碼里,我們定義了一個(gè)CUDA設(shè)備函數(shù)叫做cuApplyLayerNorm_,它的主要任務(wù)是執(zhí)行LayerNorm(層歸一化)。
//層歸一化是深度學(xué)習(xí)中的一個(gè)技巧,用來(lái)讓每一層的輸出更加標(biāo)準(zhǔn)化,有助于模型訓(xùn)練。
//我們定義了三種模板參數(shù):T是輸入數(shù)據(jù)類型,U是中間計(jì)算(比如均值和方差)的類型,V是輸出數(shù)據(jù)類型。
// output_vals, mean, invvar, vals, gamma, beta 這些都是指向不同數(shù)據(jù)的指針。
//在層歸一化中,我們通常把一個(gè)多維數(shù)據(jù)(張量)分為兩部分:一部分用來(lái)做標(biāo)準(zhǔn)化,另一部分保持原樣。
//比如,如果你有一個(gè)[batch_size,channels,height,width]形狀的4D張量,
//而你只想對(duì)最后兩個(gè)維度進(jìn)行層歸一化,那么n1是batch_size * channels,n2是height * width。
template__device__
voidcuApplyLayerNorm_(
V*__restrict__output_vals,
U*__restrict__mean,
U*__restrict__invvar,
constT*__restrict__vals,
constintn1,
constintn2,
constUepsilon,
constV*__restrict__gamma,
constV*__restrict__beta,
boolrms_only
)
{
//基本假設(shè):
// 1) blockDim.x 是 warp 的大?。ㄟ@是一個(gè)CUDA的技術(shù)細(xì)節(jié))。
// 2)輸入的張量數(shù)據(jù)在內(nèi)存中是連續(xù)的。
//
//這段代碼遍歷n1維度,每次處理一個(gè)i1索引。
//假設(shè)每個(gè)CUDA線程塊的x維度等于warp大小,確保數(shù)據(jù)處理是高效的。
//這里一個(gè)線程可能要處理多行,所以我們用gridDim.y來(lái)控制步長(zhǎng)。(因?yàn)間ridDim.x=1)
for(autoi1=blockIdx.y;i1shared;
U*buf=shared.getPointer();//創(chuàng)建一個(gè) SharedMemory 實(shí)例用于處理類型 U 的數(shù)據(jù)。
Umu,sigma2;//這里mu和sigma2分別代表均值和方差,我們接下來(lái)要計(jì)算它們。
//調(diào)用 cuWelfordMuSigma2 函數(shù)計(jì)算給定索引 i1 處的均值(mu)和方差(sigma2)。
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf,rms_only);

//定位到當(dāng)前 i1 索引處的輸入和輸出的起始位置。
constT*lvals=vals+i1*n2;
V*ovals=output_vals+i1*n2;
//計(jì)算逆方差 c_invvar,這是層歸一化中一個(gè)關(guān)鍵的步驟。
Uc_invvar=rsqrt(sigma2+epsilon);
//計(jì)算每個(gè) CUDA 塊的線程總數(shù)(numx)和當(dāng)前線程的一維索引(thrx)。
constintnumx=blockDim.x*blockDim.y;
constintthrx=threadIdx.x+threadIdx.y*blockDim.x;
//如果提供了gamma和beta參數(shù),或者我們只是在做RMS計(jì)算,我們會(huì)用一種特別的方式來(lái)計(jì)算輸出值。
if(gamma!=NULL&&(beta!=NULL||rms_only)){
for(inti=thrx;i(lvals[i]);
if(!rms_only){
//標(biāo)準(zhǔn)化當(dāng)前值,然后用gamma和beta進(jìn)行調(diào)整。
ovals[i]=gamma[i]*static_cast(c_invvar*(curr-mu))+beta[i];
}else{
////如果是RMS模式,我們稍微簡(jiǎn)化計(jì)算過(guò)程。
ovals[i]=gamma[i]*static_cast(c_invvar*curr);
}

}
}
//否則,如果沒(méi)有提供gamma和beta,我們就直接用計(jì)算出的均值和逆方差來(lái)進(jìn)行標(biāo)準(zhǔn)化。
else{
for(inti=thrx;i(lvals[i]);
if(!rms_only){
//直接進(jìn)行標(biāo)準(zhǔn)化計(jì)算。
ovals[i]=static_cast(c_invvar*(curr-mu));
}else{
//// RMS模式下的簡(jiǎn)化計(jì)算。
ovals[i]=static_cast(c_invvar*curr);
}
}
}
//在每個(gè) CUDA 塊中,僅由一個(gè)線程(線程(0,0))更新均值和逆方差。
if(threadIdx.x==0&&threadIdx.y==0){
if(!rms_only){
mean[i1]=mu;
}
invvar[i1]=c_invvar;
}
//用于同步塊內(nèi)的所有線程。
__syncthreads();
}
}

//對(duì)上個(gè)函數(shù)的參數(shù)透?jìng)鳎贿^(guò)rms_only設(shè)為False
template__global__
voidcuApplyLayerNorm(
V*__restrict__output_vals,
U*__restrict__mean,
U*__restrict__invvar,
constT*__restrict__vals,
constintn1,
constintn2,
constUepsilon,
constV*__restrict__gamma,
constV*__restrict__beta
)
{
cuApplyLayerNorm_(output_vals,mean,invvar,vals,n1,n2,epsilon,gamma,beta,false);
}

//kernel啟動(dòng)代碼,設(shè)置線程塊和線程數(shù)
template
voidHostApplyLayerNorm(
V*output,
U*mean,
U*invvar,
constT*input,
intn1,
intn2,
doubleepsilon,
constV*gamma,
constV*beta
)
{
// threads和blocks定義了CUDA內(nèi)核的線程和塊的維度。這里,每個(gè)線程塊有32×4的線程,而塊的數(shù)量由n1和GPU設(shè)備的最大網(wǎng)格大小限制決定。
autostream=at::getCurrentCUDAStream().stream();
constdim3threads(32,4,1);
constuint64_tmaxGridY=at::getCurrentDeviceProperties()->maxGridSize[1];
constdim3blocks(1,std::min((uint64_t)n1,maxGridY),1);
//這段代碼計(jì)算內(nèi)核函數(shù)需要多少共享內(nèi)存。如果threads.y大于1,它會(huì)根據(jù)U類型的大小分配足夠的內(nèi)存。
intnshared=
threads.y>1?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U):
0;
//最后,函數(shù)使用cuApplyLayerNorm kernel來(lái)執(zhí)行實(shí)際的LayerNorm操作。
// kernel函數(shù)的調(diào)用使用了之前計(jì)算的線程塊和線程配置,以及共享內(nèi)存大小和CUDA流。
cuApplyLayerNorm<<>>(
output,mean,invvar,input,n1,n2,U(epsilon),gamma,beta);
}

這段代碼包含了kernel的啟動(dòng)邏輯,包括設(shè)置block的個(gè)數(shù)以及每個(gè)block中的線程排布方式,然后在cuApplyLayerNorm_里面有一個(gè)跨線程網(wǎng)格的大循環(huán)作用在n1維度,每個(gè)線程可能會(huì)處理多行數(shù)據(jù)。而在每一行數(shù)據(jù)的處理上,調(diào)用了cuWelfordMuSigma2 函數(shù)計(jì)算給定索引 i1 處的均值(mu)和方差(sigma2),并隨后在n2維度上來(lái)計(jì)算LayerNorm的輸出,同時(shí)會(huì)在每個(gè)Block的線程(0, 0)更新cuWelfordMuSigma2算出來(lái)的均值和方差(這里的記錄的實(shí)際上是方差的逆)。

0x1.3 kernel實(shí)現(xiàn)

從上面的分析可知,整個(gè)LayerNorm實(shí)現(xiàn)的核心就是cuWelfordMuSigma2函數(shù),下面對(duì)這個(gè)函數(shù)進(jìn)行解析。

//`cuWelfordMuSigma2`是一個(gè)CUDA設(shè)備函數(shù),旨在高效計(jì)算張量某一特定維度上的均值(mu)和方差(sigma2)。
//它基于Welford算法實(shí)現(xiàn),以提高數(shù)值穩(wěn)定性。此外,該函數(shù)支持僅計(jì)算均方根(RMS)作為一種操作模式。
//模板參數(shù):定義了處理張量值(T)和執(zhí)行計(jì)算(U)時(shí)使用的數(shù)據(jù)類型。
// const T*__restrict__ vals:指向張量數(shù)據(jù)的指針。
// const int n1, n2:指定張量的維度,其中n1是參與計(jì)算的維度的大小,n2是被約減的維度的大小。
// const int i1:當(dāng)前正在處理的n1維度上的特定索引。
// U& mu, sigma2:用于存儲(chǔ)計(jì)算得出的均值和方差。
// U* buf:指向用于線程間通訊的共享內(nèi)存緩沖區(qū)的指針。
// bool rms_only:一個(gè)標(biāo)志,用于指示是否僅計(jì)算RMS(為true時(shí))或同時(shí)計(jì)算均值和方差(為false時(shí))。
template__device__
voidcuWelfordMuSigma2(
constT*__restrict__vals,
constintn1,
constintn2,
constinti1,
U&mu,
U&sigma2,
U*buf,
boolrms_only)
{
//前提條件:
// 1) blockDim.x 等于 warp 的大小。
// 2)輸入的張量在內(nèi)存中連續(xù)存儲(chǔ)。
// 3)有足夠的共享內(nèi)存可用,大小為 2*blockDim.y*sizeof(U)+ blockDim.y*sizeof(int)。
//
//在 n2 維度上計(jì)算方差和均值。
//初始化 count, mu, 和 sigma2 為零。
Ucount=U(0);
mu=U(0);
sigma2=U(0);
//確保處理的 i1 索引在張量的有效范圍內(nèi)。
if(i1(lvals[l+k]);
//根據(jù) rms_only 標(biāo)志調(diào)用相應(yīng)的函數(shù)來(lái)更新均值和方差或僅更新平方和(用于計(jì)算 RMS)。
if(!rms_only){
cuWelfordOnlineSum(curr,mu,sigma2,count);
}else{
cuRMSOnlineSum(curr,sigma2);
}
}
}
//這個(gè)循環(huán)處理了之前在步長(zhǎng)為 4*numx 的循環(huán)中未處理的張量元素。每個(gè)線程獨(dú)立處理它們剩余的部分。
for(;l(lvals[l]);
if(!rms_only){
cuWelfordOnlineSum(curr,mu,sigma2,count);
}else{
cuRMSOnlineSum(curr,sigma2);
}
}
//在同一個(gè)warp內(nèi)進(jìn)行歸約操作。
for(intl=0;l<=?4;??++l)?{
??????//?是在 CUDA 設(shè)備上進(jìn)行 warp 內(nèi)部數(shù)據(jù)交換的關(guān)鍵部分。
??????//?這行代碼用于確定在一個(gè) warp(32個(gè)線程)內(nèi),每個(gè)線程應(yīng)該從哪個(gè)“l(fā)ane”(即其他線程)獲取數(shù)據(jù)。
??????//?(1<(muB,sigma2B,countB,mu,sigma2,count);
}else{
cuChanRMSOnlineSum(sigma2B,sigma2);
}
}
//threadIdx.x==0hascorrectvaluesforeachwarp
//inter-warpreductions
//檢查是否有多個(gè) warp。如果 blockDim.y 大于 1,則表示塊中有多個(gè) warp 需要進(jìn)行reduce操作。
if(blockDim.y>1){
//為方差和均值的reduce操作分配共享內(nèi)存。ubuf 用于存儲(chǔ)方差和均值,ibuf 用于存儲(chǔ)計(jì)數(shù)。
U*ubuf=(U*)buf;
U*ibuf=(U*)(ubuf+blockDim.y);
//這個(gè)循環(huán)是對(duì) warp 間的reduce操作進(jìn)行分層合并。
for(intoffset=blockDim.y/2;offset>0;offset/=2){
//upperhalfofwarpswritetoshared
//確保只有部分線程(warp 的上半部分)將其計(jì)算的結(jié)果寫入共享內(nèi)存。
if(threadIdx.x==0&&threadIdx.y>=offset&&threadIdx.y(muB,sigma2B,countB,mu,sigma2,count);
}else{
cuChanRMSOnlineSum(sigma2B,sigma2);
}
}
__syncthreads();
}
//threadIdx.x=0&&threadIdx.y==0onlythreadthathascorrectvalues
//最終的結(jié)果由塊內(nèi)的第一個(gè)線程(threadIdx.x ==0&& threadIdx.y ==0)計(jì)算并寫入共享內(nèi)存。
if(threadIdx.x==0&&threadIdx.y==0){
if(!rms_only){
ubuf[0]=mu;
}
ubuf[1]=sigma2;
}
__syncthreads();
//如果不是只計(jì)算 RMS,則還需要更新均值 mu。
if(!rms_only){
mu=ubuf[0];
}
//計(jì)算最終的方差。
sigma2=ubuf[1]/U(n2);
//don'tcareaboutfinalvalueofcount,weknowcount==n2
}
//如果塊中只有一個(gè) warp(blockDim.y == 1),則通過(guò) WARP_SHFL 直接在 warp 內(nèi)進(jìn)行數(shù)據(jù)交換和更新。
else{
if(!rms_only){
mu=WARP_SHFL(mu,0);
}
sigma2=WARP_SHFL(sigma2/U(n2),0);
}
}

cuWelfordMuSigma2函數(shù)就是在n2維度上使用工具函數(shù)章節(jié)的Weleford方法來(lái)完成均值和方差的計(jì)算,然后這里還借助了共享內(nèi)存來(lái)做warp內(nèi)和warp間的reduce,最終得到全局的均值和方差。

前向的kernel就分析到這里,大家如果想對(duì)LayerNorm的優(yōu)化做進(jìn)一步的了解,推薦看一下OneFlow的SoftMax和LayerNorm優(yōu)化文章。CUDA優(yōu)化之LayerNorm性能優(yōu)化實(shí)踐(https://zhuanlan.zhihu.com/p/443026261) ,這篇文章也是講解了LayerNorm的前向優(yōu)化流程,文章開(kāi)頭有一張性能的圖:

adaadb1a-b3ae-11ee-8b88-92fbcf53809c.png

實(shí)際上在大模型時(shí)代,我們的隱藏層維度已經(jīng)越來(lái)越大了,所以我們?cè)趯?shí)際訓(xùn)練的時(shí)候,OneFlow版本的kernel相比于apex的layerNorm在13B之類的模型訓(xùn)練里就拿不到明顯收益了。而在CV中,由于做LayerNorm的維度可能相對(duì)小一些,所以相比于apex的LayerNorm就可以取得明顯加速。

0x2. Apex的LayerNorm反向cuda實(shí)現(xiàn)(memory_efficient相關(guān)計(jì)算)

在apex的LayerNorm反向?qū)崿F(xiàn)時(shí)我們不僅要關(guān)注它的cuda kernel是怎么寫的,還要關(guān)注memory_efficient打開(kāi)時(shí)是如何根據(jù)輸出來(lái)計(jì)算梯度的。我們知道LayerNorm需要對(duì)輸入,gamma,beta都計(jì)算梯度,介于篇幅原因,這里對(duì)實(shí)現(xiàn)得最復(fù)雜的gamma/beta的反向過(guò)程進(jìn)行走讀。

0x2.1 啟動(dòng)邏輯

這里從kernel的啟動(dòng)邏輯開(kāi)始梳理:

//這是一個(gè)模板函數(shù),支持不同的數(shù)據(jù)類型:T(輸入數(shù)據(jù)類型)、
// U(通常用于中間計(jì)算的數(shù)據(jù)類型,默認(rèn)為float)、V(輸出數(shù)據(jù)類型,默認(rèn)與T相同)。
//參數(shù)包括輸出梯度(dout)、均值(mean)、方差倒數(shù)(invvar)、輸入或輸出的PyTorch張量(input_or_output)、
//兩個(gè)維度參數(shù)(n1、n2)、gamma和beta參數(shù)、用于數(shù)值穩(wěn)定的epsilon、輸入梯度(grad_input)、
// gamma梯度(grad_gamma)和beta梯度(grad_beta)、以及一個(gè)指示是否優(yōu)化內(nèi)存使用的布爾值(memory_efficient)。
template
voidHostLayerNormGradient(
constV*dout,
constU*mean,
constU*invvar,
at::Tensor*input_or_output,
intn1,
intn2,
constV*gamma,
constV*beta,
doubleepsilon,
T*grad_input,
V*grad_gamma,
V*grad_beta,
boolmemory_efficient
)
{
//獲取當(dāng)前CUDA流以用于后續(xù)的CUDA內(nèi)核調(diào)用。
autostream=at::getCurrentCUDAStream().stream();

//如果gamma和beta不為NULL,函數(shù)會(huì)計(jì)算它們的梯度。
//這涉及兩個(gè)CUDA內(nèi)核的調(diào)用:cuComputePartGradGammaBeta和cuComputeGradGammaBeta。
if(gamma!=NULL&&beta!=NULL){
//computegrad_gamma(j)andgrad_beta(j)
// part_size是分塊計(jì)算梯度時(shí)的部分大小。
constintpart_size=16;
// threads2定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×4×1)。
constdim3threads2(32,4,1);
// blocks2定義了CUDA網(wǎng)格中的塊數(shù)量,其中,n2維度被分成多個(gè)塊,以確保每個(gè)塊可以處理n2中的一部分。
constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
//這部分代碼計(jì)算用于CUDA內(nèi)核的共享內(nèi)存大小。nshared2_a和nshared2_b是基于線程和塊維度的兩種不同共享內(nèi)存大小估算。
constintnshared2_a=2*sizeof(U)*threads2.y*threads2.y*(threads2.x+1);
constintnshared2_b=threads2.x*threads2.y*sizeof(U);
//最終選擇較大的一個(gè)估算值作為實(shí)際的共享內(nèi)存大?。╪shared2)。
constintnshared2=nshared2_a>nshared2_b?nshared2_a:nshared2_b;
//note(mkozuki):Icanhardcodepart_grad_gamma'sdtypeasfloatgiventhat
//the`cuda_layer_norm_gradient`doesn'tsupportdouble.
//根據(jù)輸入或輸出張量的數(shù)據(jù)類型決定局部梯度張量part_grad_gamma和part_grad_beta的數(shù)據(jù)類型。
//如果輸入或輸出是半精度浮點(diǎn)數(shù)(Half)或BFloat16,則使用單精度浮點(diǎn)數(shù)(Float);否則,使用輸入或輸出的相同數(shù)據(jù)類型。
constautopart_grad_dtype=
(input_or_output->scalar_type()==at::Half||input_or_output->scalar_type()==at::BFloat16)?
at::Float:
input_or_output->scalar_type();
//創(chuàng)建兩個(gè)新的PyTorch張量part_grad_gamma和part_grad_beta,用于存儲(chǔ)gamma和beta的局部梯度計(jì)算結(jié)果。
at::Tensorpart_grad_gamma=at::empty({part_size,n2},input_or_output->options().dtype(part_grad_dtype));
at::Tensorpart_grad_beta=at::empty_like(part_grad_gamma);
//使用BOOL_SWITCH宏處理memory_efficient參數(shù),以決定是否使用內(nèi)存高效版本的CUDA內(nèi)核。
//調(diào)用cuComputePartGradGammaBeta內(nèi)核計(jì)算gamma和beta的梯度。
//這個(gè)內(nèi)核函數(shù)接收必要的輸入?yún)?shù),并將梯度結(jié)果寫入part_grad_gamma和part_grad_beta張量。
BOOL_SWITCH(memory_efficient,MemoryEfficient,[&]{
autokernel=&cuComputePartGradGammaBeta;
kernel<<>>(
dout,
input_or_output->DATA_PTR(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
beta,
part_grad_gamma.DATA_PTR(),
part_grad_beta.DATA_PTR(),
epsilon,
false);
});

//定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×8×1)。
constdim3threads3(32,8,1);
//定義了CUDA網(wǎng)格中的塊數(shù)量。在這里,n2維度被分成多個(gè)塊,每個(gè)塊的大小由threads2.x(之前定義的線程數(shù)量)確定。
constdim3blocks3((n2+threads2.x-1)/threads2.x,1,1);
//這行代碼計(jì)算了cuComputeGradGammaBeta內(nèi)核所需的共享內(nèi)存大小。它基于threads3線程塊的維度和數(shù)據(jù)類型U的大小。
constintnshared3=threads3.x*threads3.y*sizeof(U);
//kernel接收局部梯度張量(part_grad_gamma和part_grad_beta)、塊大?。╬art_size)、
//維度參數(shù)(n1和n2)和指向梯度輸出的指針(grad_gamma和grad_beta)。
cuComputeGradGammaBeta<<>>(
part_grad_gamma.DATA_PTR(),
part_grad_beta.DATA_PTR(),
part_size,
n1,n2,
grad_gamma,
grad_beta,
false);
}
...
}

這里省略了計(jì)算輸入梯度的啟動(dòng)代碼,只看計(jì)算gamma和beta梯度的代碼??梢园l(fā)現(xiàn),這里對(duì)gamma和beta的梯度進(jìn)行計(jì)算時(shí)使用了分塊計(jì)算的方式,首先會(huì)調(diào)用cuComputePartGradGammaBeta這個(gè)kernel計(jì)算出一個(gè)部分gamma和部分beta,也就是part_grad_gamma和part_grad_beta,需要注意這個(gè)kernel開(kāi)啟的線程塊為:const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1),其中part_size=16,此外每個(gè)線程塊中的線程排布為:const dim3 threads2(32,4,1),即每個(gè)線程塊有128個(gè)線程。我們可以簡(jiǎn)單算一下block2的大小,threads2.x=32,那么blocks2=(n2/32,16,1),也就是一共會(huì)有n2/2個(gè)線程塊。

使用cuComputePartGradGammaBeta計(jì)算完局部gamma和beta的grad之后,會(huì)調(diào)用cuComputeGradGammaBeta這個(gè)kernel來(lái)匯總?cè)值膅amma和beta的梯度。這里開(kāi)啟的線程塊為:const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1),而每個(gè)線程塊里面有256個(gè)線程,排布為const dim3 threads3(32,8,1)。

現(xiàn)在了解了線程塊的組織方式就需要去kernel實(shí)現(xiàn)里面對(duì)應(yīng)看一下具體是怎么計(jì)算的。

0x2.2 kernel計(jì)算邏輯

首先來(lái)看分段計(jì)算gamma和beta梯度的kernel實(shí)現(xiàn),注釋如下:

// part_size是分塊計(jì)算梯度時(shí)的部分大小。
//constintpart_size=16;
// threads2定義了每個(gè)CUDA線程塊中的線程數(shù)量(32×4×1)。
//constdim3threads2(32,4,1);
// blocks2定義了CUDA網(wǎng)格中的塊數(shù)量,其中,n2維度被分成多個(gè)塊,以確保每個(gè)塊可以處理n2中的一部分。
//constdim3blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
//->
//blockDim.x=32,blockDim.y=4,gridDim.y=16
//假設(shè)n1=4,n2=256,并且當(dāng)前是第一個(gè)線程塊
template__global__
voidcuComputePartGradGammaBeta(
constV*__restrict__dout,
constT*__restrict__input_or_output,
constintn1,
constintn2,
constU*__restrict__mean,
constU*__restrict__invvar,
Uepsilon,
constV*__restrict__gamma,
constV*__restrict__beta,
U*part_grad_gamma,
U*part_grad_beta,
constdoubleeps,
boolrms_only)
{
// numsegs_n1計(jì)算n1維度(4)被分成多少段。使用blockDim.y*blockDim.y(16)作為分段大小。
//帶入值:numsegs_n1 =(4 + 16 - 1)/ 16 = 1。
constintnumsegs_n1=(n1+blockDim.y*blockDim.y-1)/(blockDim.y*blockDim.y);
// segs_per_block計(jì)算每個(gè)線程塊要處理的段數(shù)。
//帶入值:segs_per_block =(1 + 16 - 1)/ 16 = 1。
constintsegs_per_block=(numsegs_n1+gridDim.y-1)/gridDim.y;
//這些行計(jì)算當(dāng)前線程塊開(kāi)始和結(jié)束處理n1維度的索引
//i1_beg和i1_beg_plus_one相差segs_per_block*blockDim.y*blockDim.y=1*4*4=16
//帶入blockIdx.y =0:i1_beg =0* 1 * 4 * 4 =0, i1_beg_plus_one = 1 * 1 * 4 * 4 = 16,i1_end = min(16, 4)= 4
constinti1_beg=blockIdx.y*segs_per_block*blockDim.y*blockDim.y;
constinti1_beg_plus_one=(blockIdx.y+1)*segs_per_block*blockDim.y*blockDim.y;
constinti1_end=i1_beg_plus_oneshared;
U*buf=shared.getPointer();//bufhasatleastblockDim.x*blockDim.y*blockDim.y+(blockDim.y-1)*(blockDim.x/blockDim.y)elements
U*warp_buf1=(U*)buf;//大小是31*4*4=496
U*warp_buf2=warp_buf1+blockDim.y*blockDim.y*row_stride;//大小是3*(32/4)=24

//computepartialsumsfromstridedinputs
//dothistoincreasenumberofloadsinflight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only);
// for循環(huán)處理每個(gè)數(shù)據(jù)塊(由i1_beg和i1_end確定)。
//它在數(shù)據(jù)塊之間以步幅blockDim.y*blockDim.y迭代,允許不同的線程塊處理不同的數(shù)據(jù)區(qū)域。
for(inti1_block=i1_beg+blockDim.y*blockDim.y;i1_block(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps,rms_only);
}
//確保在所有線程完成其加載和處理操作之前,沒(méi)有線程會(huì)繼續(xù)執(zhí)行后續(xù)的操作。
__syncthreads();
//inter-warpreductions
//sumwithineachwarp
//這部分代碼執(zhí)行內(nèi)部歸約,計(jì)算每個(gè)warp內(nèi)部的部分和。
// acc1和acc2分別用于累積來(lái)自warp_buf1和warp_buf2的值。這些緩沖區(qū)包含之前步驟計(jì)算的中間結(jié)果。
Uacc1=U(0);
Uacc2=U(0);
//內(nèi)部循環(huán)對(duì)于blockDim.y內(nèi)的每一行進(jìn)行累加,if (!rms_only)條件檢查是否需要執(zhí)行特定的分支邏輯。
//需要特別注意,這個(gè)累加實(shí)際上是在列方向上也就是n2維度,在n2維度上一個(gè)線程負(fù)責(zé)計(jì)算blockDim.y列
for(intk=0;k1;offset/=2){
//在每次迭代中,只有threadIdx.y小于當(dāng)前offset的線程會(huì)參與計(jì)算,這樣可以避免重復(fù)的工作。
if(threadIdx.y

在理解這段代碼之前,有一個(gè)大前提,那就是這里的訪問(wèn)方式是n1是和blockDim.y綁定的,而n2是和blockDim.x綁定的,而且以二維矩陣的角度來(lái)看,n1是在列方向,而n2是在行的方向。然后const int row_stride = blockDim.x+1這一行是對(duì)共享內(nèi)存進(jìn)行padding避免Bank Conflict的,而在計(jì)算時(shí)對(duì)共享內(nèi)存的訪問(wèn)就是按照列來(lái)訪問(wèn),徹底避免bank conflict。

這也是為什么cuLoadWriteStridedInputs和cuLoadAddStridedInputs函數(shù)名中有一個(gè)Strided,這也暗示了它們的訪問(wèn)模式是跨stride的。剩下的部分其實(shí)和前向就比較類似了,做warp內(nèi)和warp間的reduce。

另外一個(gè)值得注意的點(diǎn)是在cuLoadWriteStridedInputs和cuLoadAddStridedInputs計(jì)算時(shí),會(huì)根據(jù)memory_efficient開(kāi)關(guān)選擇不同的計(jì)算公式,分別從輸入和輸出來(lái)計(jì)算出梯度,達(dá)到kernel內(nèi)部重計(jì)算的目的。

//這段代碼定義了一個(gè)名為cuLoadWriteStridedInputs的CUDA設(shè)備函數(shù)模板,用于在計(jì)算LayerNorm的梯度時(shí),
//從輸入張量中加載數(shù)據(jù)并進(jìn)行必要的計(jì)算,將結(jié)果存儲(chǔ)在 warp 緩沖區(qū)中。這個(gè)函數(shù)支持內(nèi)存高效模式(MemoryEfficient)。
//模板參數(shù) T, U, V 代表不同的數(shù)據(jù)類型。
// bool MemoryEfficient 用于選擇是否采用內(nèi)存高效的方式處理數(shù)據(jù)。
//__device__表明這是一個(gè) CUDA 設(shè)備函數(shù)。
//函數(shù)參數(shù)包括各種用于LayerNorm梯度計(jì)算的數(shù)據(jù),
//如輸入/輸出張量、梯度張量 dout、均值 mean、逆方差 invvar、縮放參數(shù) gamma、偏移參數(shù) beta 等。
template__device__
voidcuLoadWriteStridedInputs(
constinti1_block,
constintthr_load_row_off,
constintthr_load_col_off,
constinti2_off,
constintrow_stride,
U*warp_buf1,
U*warp_buf2,
constT*input_or_output,
constV*dout,
constinti1_end,
constintn2,
constU*__restrict__mean,
constU*__restrict__invvar,
constV*__restrict__gamma,
constV*__restrict__beta,
constdoubleeps,
boolrms_only
)
{
//計(jì)算 i1,表示當(dāng)前處理的行索引。
inti1=i1_block+thr_load_row_off;
if(i1(input_or_output[load_idx]);
Ucurr_dout=static_cast(dout[load_idx]);
//根據(jù) rms_only 和 MemoryEfficient 的值,使用不同的公式計(jì)算梯度,并將結(jié)果存儲(chǔ)在 warp 緩沖區(qū)中。
if(!rms_only){
warp_buf1[write_idx]=curr_dout;
if(MemoryEfficient){
Ucurr_beta=static_cast(beta[i2]);
warp_buf2[write_idx]=curr_dout*(c_h-curr_beta)/static_cast(clamp_by_magnitude(gamma[i2],eps));
}else{
warp_buf2[write_idx]=curr_dout*(c_h-mean[i1])*invvar[i1];
}
}else{
if(MemoryEfficient){
warp_buf2[write_idx]=curr_dout*(c_h)/static_cast(clamp_by_magnitude(gamma[i2],eps));
}else{
warp_buf2[write_idx]=curr_dout*(c_h)*invvar[i1];
}
}
}else{
//對(duì)于超出 n2 范圍的索引,將相應(yīng)的 warp 緩沖區(qū)位置設(shè)置為0。
if(!rms_only){
warp_buf1[write_idx]=U(0);
}
warp_buf2[write_idx]=U(0);
}
}
}else{
//對(duì)于超出 n1 范圍的索引,也將相應(yīng)的 warp 緩沖區(qū)位置設(shè)置為0。
for(intk=0;k

執(zhí)行完cuComputePartGradGammaBeta這個(gè)kernel之后,它的輸出part_grad_gamma和part_grad_beta分別以行為n2列為n1的內(nèi)存視角保存了LayerNorm的局部均值和方差的梯度。

接下來(lái)會(huì)使用cuComputeGradGammaBeta這個(gè)kernel來(lái)計(jì)算全局的均值和方差的梯度,由于局部計(jì)算的時(shí)候分塊大小是16,而每個(gè)線程負(fù)責(zé)了4行的計(jì)算,那么這里還需要累積16/4=4次,以得到當(dāng)前行的所有局部梯度的和。

//blockDim.x=n2/32,blockDim.y=1
//threadDim.x=32,threadDim.y=8
template__global__
voidcuComputeGradGammaBeta(
constU*part_grad_gamma,
constU*part_grad_beta,
constintpart_size,
constintn1,
constintn2,
V*grad_gamma,
V*grad_beta,
boolrms_only)
{
//sumpartialgradientsforgammaandbeta
SharedMemoryshared;
U*buf=shared.getPointer();
//計(jì)算每個(gè)線程的全局索引i2,用于確定它在n2維度上的位置。
inti2=blockIdx.x*blockDim.x+threadIdx.x;
//如果線程索引i2小于n2的大小,該線程會(huì)參與計(jì)算。
if(i2=1;offset/=2){
//tophalfwritetosharedmemory
//在這個(gè)歸約階段,線程首先將其累加結(jié)果寫入共享內(nèi)存,然后從共享內(nèi)存讀取并繼續(xù)累加。
if(threadIdx.y>=offset&&threadIdx.y

注意,for (int offset = blockDim.y/2; offset >= 1; offset /= 2) 這個(gè)循環(huán)包起來(lái)的代碼在這里不會(huì)工作,因?yàn)檫@個(gè)kernel的啟動(dòng)設(shè)置中 blockDim.y=1。另外,我們知道輸入的數(shù)據(jù)已經(jīng)是寫到全局內(nèi)存里面的了,已經(jīng)是同步之后的了,然后每個(gè)線程累積4次這個(gè)過(guò)程也是從global memory里面先讀再計(jì)算最后寫回全局內(nèi)存,所以確實(shí)不需要再reduce了。

關(guān)于memory_efficient開(kāi)關(guān)打開(kāi)時(shí)的梯度計(jì)算公式,按照 https://github.com/NVIDIA/apex/pull/1715 這個(gè)pr 來(lái)看應(yīng)該就是把原始的輸入用重計(jì)算的輸入替換之后再代入到之前的梯度計(jì)算公式中算出來(lái)的。

adb8e214-b3ae-11ee-8b88-92fbcf53809c.png?adc91d3c-b3ae-11ee-8b88-92fbcf53809c.png

https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L579 這里就對(duì)應(yīng)了對(duì)gamma的梯度,https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/apex/layer_norm_cuda_kernel.cu#L582C5-L582C5 這里則對(duì)應(yīng)了對(duì)beta的梯度。這里的就等于,公式和代碼實(shí)現(xiàn)都能完整對(duì)應(yīng)上。

0x3. 總結(jié)

這篇文章記錄了筆者在研究大模型訓(xùn)練中偶然見(jiàn)到的一個(gè)Trick的代碼解密過(guò)程,希望對(duì)學(xué)習(xí)cuda的小伙伴有所幫助,謝謝大家。






審核編輯:劉清

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • NVIDIA
    +關(guān)注

    關(guān)注

    14

    文章

    4856

    瀏覽量

    102711
  • RMS
    RMS
    +關(guān)注

    關(guān)注

    2

    文章

    137

    瀏覽量

    35720
  • python
    +關(guān)注

    關(guān)注

    55

    文章

    4768

    瀏覽量

    84376
  • CUDA
    +關(guān)注

    關(guān)注

    0

    文章

    121

    瀏覽量

    13585
  • GPU芯片
    +關(guān)注

    關(guān)注

    1

    文章

    303

    瀏覽量

    5770

原文標(biāo)題:【BBuf的CUDA筆記】十二,LayerNorm/RMSNorm的重計(jì)算實(shí)現(xiàn)

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    【大規(guī)模語(yǔ)言模型:從理論到實(shí)踐】- 每日進(jìn)步一點(diǎn)點(diǎn)

    :相比LayerNormRMSNorm去除了平移部分,只保留了縮放部分,從而減少了計(jì)算均值和平移系數(shù)的部分,訓(xùn)練速度更快。 Deep Normalization(DeepNorm) 原理:由微軟提出
    發(fā)表于 05-31 19:54

    【求助】--想做個(gè)云端體重計(jì)的項(xiàng)目,求達(dá)人指點(diǎn)

    能把體重計(jì)的數(shù)據(jù)通過(guò)藍(lán)牙或類似方式傳到手機(jī)上。希望能結(jié)交到懂相關(guān)技術(shù)的朋友。
    發(fā)表于 05-27 14:57

    重計(jì)算程序

    快速方便的計(jì)算出你的體重指數(shù)。
    發(fā)表于 01-11 19:09

    求大神寫一個(gè)用FPGA,Cyclone 2芯片的稱重計(jì)

    求大神寫一個(gè)用FPGA,Cyclone 2芯片的稱重計(jì),采用了稱重傳感器,HX711AD模塊,想達(dá)到在LCD1602上顯示,還有超過(guò)量程有指示燈亮的功能,最好還有鍵盤輸入單價(jià)可以顯示總價(jià)的計(jì)算功能
    發(fā)表于 05-27 19:00

    求基于單片機(jī)的智能體重計(jì)的Proteus仿真圖

    要求:體重計(jì)用藍(lán)牙與體重計(jì)連接,并在手機(jī)端統(tǒng)計(jì)若干次的體重。
    發(fā)表于 05-07 16:36

    如何去實(shí)現(xiàn)一種基于51單片機(jī)的HX711稱重計(jì)的設(shè)計(jì)?

    HX711是什么?HX711有哪些優(yōu)點(diǎn)?HX711的管腳有哪些?其功能是什么?如何去實(shí)現(xiàn)一種基于51單片機(jī)的HX711稱重計(jì)的設(shè)計(jì)?
    發(fā)表于 07-19 07:32

    金屬材料單重計(jì)算 軟件

    金屬材料單重計(jì)算 軟件 金屬材料單重計(jì)算 軟件 金屬材料單重計(jì)算 軟件
    發(fā)表于 09-26 23:03

    基于AT89S51的垃圾稱重計(jì)費(fèi)控制系統(tǒng)

    本文設(shè)計(jì)了一種基于AT89S51單片機(jī)的垃圾稱重計(jì)費(fèi)控制系統(tǒng)。與其他控制系統(tǒng)相比,單片機(jī)系統(tǒng)具有體積小巧、成本低廉等優(yōu)勢(shì)。
    發(fā)表于 08-17 14:21 ?2767次閱讀
    基于AT89S51的垃圾稱<b class='flag-5'>重計(jì)</b>費(fèi)控制系統(tǒng)

    基于通過(guò)熱電偶傳感器來(lái)提高稱重計(jì)的測(cè)量精度設(shè)計(jì)

    重計(jì)應(yīng)用在從浴室到工廠車間的各種場(chǎng)合中,滿量程從小于250磅到上千噸。稱重計(jì)都是基于薄膜金屬應(yīng)變片加上精心設(shè)計(jì)的金屬桿結(jié)構(gòu),這些應(yīng)變片連接成傳統(tǒng)的電橋結(jié)構(gòu)以實(shí)現(xiàn)最大的靈敏度。它通常可以提供1~4mV/V的滿量程輸出,而采用5V
    發(fā)表于 09-07 15:42 ?1612次閱讀
    基于通過(guò)熱電偶傳感器來(lái)提高稱<b class='flag-5'>重計(jì)</b>的測(cè)量精度設(shè)計(jì)

    20個(gè)電氣實(shí)用小工具負(fù)荷、電阻算、無(wú)功補(bǔ)償、變壓器等計(jì)算軟件

    軟件2005,電氣設(shè)備容量計(jì)算軟件2005,動(dòng)力照明系統(tǒng)電纜設(shè)計(jì),多功能計(jì)算器,負(fù)荷計(jì)算,焊接材料選擇,金屬材料單重計(jì)算,金屬材料單重計(jì)算
    發(fā)表于 11-07 16:41 ?65次下載

    電池修復(fù)技術(shù):比重與比重計(jì)制作的說(shuō)明

    電池內(nèi)部雜質(zhì)(特別是鐵離子)對(duì)電瓶的危害很大,會(huì)造成電瓶自放電,縮短自身壽命。因此,在注入硫酸和水時(shí),要注意純度。 比重計(jì)是測(cè)電解液的工具,但市售的比重計(jì)測(cè)量時(shí)需要較多電解液,難以使用。買光學(xué)比重
    發(fā)表于 05-18 17:19 ?1108次閱讀
    電池修復(fù)技術(shù):比重與比<b class='flag-5'>重計(jì)</b>制作的說(shuō)明

    聚乙烯比重計(jì)的主要特點(diǎn)有哪些

      聚乙烯比重計(jì)采用阿基米得原理浮力法、水中置換法,準(zhǔn)確、直讀量測(cè)數(shù)值。 適用于:聚乙烯、密封件、 粉末冶金、成品、含油率、有機(jī)溶劑、塑膠管材、橡膠塑料、薄膜、電纜、玻璃工業(yè)、液體、添加助劑、新材料研究實(shí)驗(yàn)室。一體成型的設(shè)計(jì),大大簡(jiǎn)化了操作,又能保證測(cè)量的攜帶和測(cè)試。
    發(fā)表于 09-29 13:37 ?310次閱讀

    塑料顆粒比重計(jì)的作用和優(yōu)勢(shì)

    、CCC、VDE等各國(guó)標(biāo)準(zhǔn)規(guī)范。 塑料顆粒比重計(jì)是目前使用群體zui多的數(shù)顯密度測(cè)量?jī)x器,精度千分之一,使用水當(dāng)介質(zhì),僅二個(gè)步驟,即可顯示密度值。與傳統(tǒng)測(cè)量工程塑料顆粒的比重測(cè)試儀器相比,本機(jī)無(wú)需人工計(jì)算,操作方便省時(shí)、測(cè)量、。 塑料顆粒比
    發(fā)表于 10-08 16:32 ?1324次閱讀

    具有身體成分測(cè)量功能的體重計(jì)參考設(shè)計(jì)

    電子發(fā)燒友網(wǎng)站提供《具有身體成分測(cè)量功能的體重計(jì)參考設(shè)計(jì).zip》資料免費(fèi)下載
    發(fā)表于 11-08 10:34 ?3次下載
    具有身體成分測(cè)量功能的體<b class='flag-5'>重計(jì)</b>參考設(shè)計(jì)

    SNx5DPHY440SS CSI-2/DSI DPHY 重計(jì)時(shí)器數(shù)據(jù)表

    電子發(fā)燒友網(wǎng)站提供《SNx5DPHY440SS CSI-2/DSI DPHY 重計(jì)時(shí)器數(shù)據(jù)表.pdf》資料免費(fèi)下載
    發(fā)表于 06-25 11:07 ?0次下載
    SNx5DPHY440SS CSI-2/DSI DPHY <b class='flag-5'>重計(jì)</b>時(shí)器數(shù)據(jù)表