0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

LLM的長度外推淺談

深度學(xué)習(xí)自然語言處理 ? 來源:NLP工作站 ? 2023-07-28 17:37 ? 次閱讀

一、NBCE

NBCE:使用樸素貝葉斯擴(kuò)展LLM的Context處理長度

蘇神最早提出的擴(kuò)展LLM的context方法,基于bayes啟發(fā)得到的公式:

fd1b2440-2d29-11ee-815d-dac502259ad0.pngfd312d9e-2d29-11ee-815d-dac502259ad0.png

問答下實(shí)測確實(shí)不錯(cuò),在較長context下的閱讀理解還算好用。

局限性是,無序性,即無法識別Context的輸入順序,這在續(xù)寫故事等場景可能表現(xiàn)欠佳,做一些依賴每個(gè)context生成答案,比如提取文檔摘要,效果較差。

outputs=model(input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
use_cache=True,
past_key_values=past_key_values
)
past_key_values=outputs.past_key_values

#=====核心代碼開始=====
beta=0.25
probas=torch.nn.functional.softmax(outputs.logits[:,-1],dim=-1)
logits=probas.log()
k=(probas*logits).sum(dim=-1)[1:].argmax()+1
logits_max=logits[k]
logits_uncond=logits[0]
logits=(1+beta)*logits_max-beta*logits_uncond
#=====核心代碼結(jié)束=====

#構(gòu)建分布,采樣
tau=0.01#tau=1是標(biāo)準(zhǔn)的隨機(jī)采樣,tau->0則是貪心搜索
probas=torch.nn.functional.softmax(logits[None]/tau,dim=-1)
next_tokens=torch.multinomial(probas,num_samples=1).squeeze(1)

此處代碼,圖片,文本均選自科學(xué)空間。

二、線性內(nèi)插

llama基于rotary embedding在2048長度上預(yù)訓(xùn)練,該方法通過將position壓縮到0~2048之間,從而達(dá)到長度外推的目的。

longchat將模型微調(diào)為上下文長度外擴(kuò)為16384,壓縮比為 8。例如,position_ids = 10000 的 token 變?yōu)閜osition_ids = 10000 / 8 = 1250,相鄰 token 10001 變?yōu)?10001 / 8 = 1250.125

該方法的缺陷是需要進(jìn)行一定量的微調(diào),讓模型來適應(yīng)這種改變。

importtorch
importtransformers
importtransformers.models.llama.modeling_llama
fromeinopsimportrearrange

fromfunctoolsimportpartial

classCondenseRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,ratio,max_position_embeddings=2048,base=10000,device=None):
super().__init__()
inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim))
self.register_buffer("inv_freq",inv_freq)

#Buildheretomake`torch.jit.trace`work.
self.ratio=ratio
max_position_embeddings*=ratio
print(f"CondensingPositionalembeddingsfrom{max_position_embeddings}to{max_position_embeddings//ratio}")
self.max_seq_len_cached=max_position_embeddings
t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)/ratio
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1)
dtype=torch.get_default_dtype()
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False)

defforward(self,x,seq_len=None):
#x:[bs,num_attention_heads,seq_len,head_size]
#This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase.
ifseq_len>self.max_seq_len_cached:
self.max_seq_len_cached=seq_len
t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)/self.ratio
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False)
return(
self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
)

defreplace_llama_with_condense(ratio):
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding=partial(CondenseRotaryEmbedding,ratio=ratio)

三、NTK-Aware Scaled RoPE

RoPE是一種β進(jìn)制編碼//spaces.ac.cn/archives/9675

fd4b1808-2d29-11ee-815d-dac502259ad0.png

有意思的解釋一下,RoPE 的行為就像一個(gè)時(shí)鐘。12小時(shí)時(shí)鐘基本上是一個(gè)維度為 3、底數(shù)為 60 的 RoPE。因此,每秒鐘,分針轉(zhuǎn)動(dòng) 1/60 分鐘,每分鐘,時(shí)針轉(zhuǎn)動(dòng) 1/60。

現(xiàn)在,如果將時(shí)間減慢 4 倍,那就是二使用的線性RoPE 縮放。不幸的是,現(xiàn)在區(qū)分每一秒,因?yàn)楝F(xiàn)在秒針幾乎每秒都不會(huì)移動(dòng)。

因此,如果有人給你兩個(gè)不同的時(shí)間,僅相差一秒,你將無法從遠(yuǎn)處區(qū)分它們。NTK-Aware RoPE 擴(kuò)展不會(huì)減慢時(shí)間。一秒仍然是一秒,但它會(huì)使分鐘減慢 1.5 倍,將小時(shí)減慢 2 倍。

這樣,您可以將 90 分鐘容納在一個(gè)小時(shí)中,將 24 小時(shí)容納在半天中。

所以現(xiàn)在你基本上有了一個(gè)可以測量 129.6k 秒而不是 43.2k 秒的時(shí)鐘。由于在查看時(shí)間時(shí)不需要精確測量時(shí)針,因此與秒相比,更大程度地縮放小時(shí)至關(guān)重要。

不想失去秒針的精度,但可以承受分針甚至?xí)r針的精度損失。

importtransformers

old_init=transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
defntk_scaled_init(self,dim,max_position_embeddings=2048,base=10000,device=None):

#Themethodisjustthesethreelines
max_position_embeddings=16384
a=8#Alphavalue
base=base*a**(dim/(dim-2))#Basechangeformula

old_init(self,dim,max_position_embeddings,base,device)
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__=ntk_scaled_init

四、Dynamically Scaled RoPE

fd56100a-2d29-11ee-815d-dac502259ad0.png

對于上面的方法二、三,都涉及到一個(gè)超參數(shù)α,用于調(diào)節(jié)縮放比例,該方法是通過序列長度動(dòng)態(tài)選擇正確的比例參數(shù),效果可以看上圖。

對于線性插值,前 2k 上下文的精確位置值,然后在模型逐個(gè)生成標(biāo)記時(shí)重新計(jì)算每個(gè)新序列長度的位置向量。本質(zhì)上,將比例設(shè)置為原始模型上下文長度/當(dāng)前序列長度。

對于動(dòng)態(tài) NTK,α 的縮放設(shè)置為 (α * 當(dāng)前序列長度 / 原始模型上下文長度) - (α - 1)。隨著序列長度的增加動(dòng)態(tài)縮放超參數(shù)。

importmath
importtorch

classLlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
def__init__(self,dim,max_position_embeddings=2048,base=10000,ntk=False,device=None):
super().__init__()
self.ntk=ntk
self.base=base
self.dim=dim
self.max_position_embeddings=max_position_embeddings
inv_freq=1.0/(base**(torch.arange(0,dim,2).float().to(device)/dim))
self.register_buffer("inv_freq",inv_freq)

#Buildheretomake`torch.jit.trace`work.
self.max_seq_len_cached=max_position_embeddings
t=torch.arange(self.max_seq_len_cached,device=self.inv_freq.device,dtype=self.inv_freq.dtype)
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1)
dtype=torch.get_default_dtype()
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(dtype),persistent=False)

defforward(self,x,seq_len=None):
#x:[bs,num_attention_heads,seq_len,head_size]
#This`if`blockisunlikelytoberunafterwebuildsin/cosin`__init__`.Keepthelogicherejustincase.
ifseq_len>self.max_seq_len_cached:
self.max_seq_len_cached=seq_len
ifself.ntk:
base=self.base*((self.ntk*seq_len/self.max_position_embeddings)-(self.ntk-1))**(self.dim/(self.dim-2))
inv_freq=1.0/(base**(torch.arange(0,self.dim,2).float().to(x.device)/self.dim))
self.register_buffer("inv_freq",inv_freq)
t=torch.arange(self.max_seq_len_cached,device=x.device,dtype=self.inv_freq.dtype)
ifnotself.ntk:
t*=self.max_position_embeddings/seq_len
freqs=torch.einsum("i,j->ij",t,self.inv_freq)
#Differentfrompaper,butitusesadifferentpermutationinordertoobtainthesamecalculation
emb=torch.cat((freqs,freqs),dim=-1).to(x.device)
self.register_buffer("cos_cached",emb.cos()[None,None,:,:].to(x.dtype),persistent=False)
self.register_buffer("sin_cached",emb.sin()[None,None,:,:].to(x.dtype),persistent=False)
return(
self.cos_cached[:,:,:seq_len,...].to(dtype=x.dtype),
self.sin_cached[:,:,:seq_len,...].to(dtype=x.dtype),
)

五、consistent of Dynamically Scaled RoPE

fd799f3e-2d29-11ee-815d-dac502259ad0.png

方法四存在一個(gè)問題是,因?yàn)棣潦莿?dòng)態(tài)的,因?yàn)榻獯a是有cache的,所以,在生成第100個(gè)token時(shí),算的α和第200個(gè)token時(shí),算的α?xí)r不一致的。fd9f8a78-2d29-11ee-815d-dac502259ad0.png

query和key的rotation base不一致,正確的應(yīng)該時(shí)這樣

fda853ec-2d29-11ee-815d-dac502259ad0.png

importmath
fromtypingimportList,Optional,Tuple,Union

importtorch
importtorch.nn.functionalasF
importtorch.utils.checkpoint
fromtorchimportnn
fromtransformers.models.llama.modeling_llamaimportrepeat_kv,apply_rotary_pos_emb
fromtransformers.models.llama.modeling_llamaimportLlamaAttention

defforward(
self,
hidden_states:torch.Tensor,
attention_mask:Optional[torch.Tensor]=None,
position_ids:Optional[torch.LongTensor]=None,
past_key_value:Optional[Tuple[torch.Tensor]]=None,
output_attentions:bool=False,
use_cache:bool=False,
)->Tuple[torch.Tensor,Optional[torch.Tensor],Optional[Tuple[torch.Tensor]]]:
bsz,q_len,_=hidden_states.size()

ifself.pretraining_tp>1:
key_value_slicing=(self.num_key_value_heads*self.head_dim)//self.pretraining_tp
query_slices=self.q_proj.weight.split((self.num_heads*self.head_dim)//self.pretraining_tp,dim=0)
key_slices=self.k_proj.weight.split(key_value_slicing,dim=0)
value_slices=self.v_proj.weight.split(key_value_slicing,dim=0)

query_states=[F.linear(hidden_states,query_slices[i])foriinrange(self.pretraining_tp)]
query_states=torch.cat(query_states,dim=-1)

key_states=[F.linear(hidden_states,key_slices[i])foriinrange(self.pretraining_tp)]
key_states=torch.cat(key_states,dim=-1)

value_states=[F.linear(hidden_states,value_slices[i])foriinrange(self.pretraining_tp)]
value_states=torch.cat(value_states,dim=-1)

else:
query_states=self.q_proj(hidden_states)
key_states=self.k_proj(hidden_states)
value_states=self.v_proj(hidden_states)

query_states=query_states.view(bsz,q_len,self.num_heads,self.head_dim).transpose(1,2)
key_states=key_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)
value_states=value_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)

kv_seq_len=key_states.shape[-2]
ifpast_key_valueisnotNone:
kv_seq_len+=past_key_value[0].shape[-2]
cos,sin=self.rotary_emb(value_states,seq_len=kv_seq_len)

ifpast_key_valueisnotNone:
#reusekw/oRoPE
key_states=torch.cat([past_key_value[0],key_states],dim=2)

#applyRoPEafterretrievingallkeysandqueries
query_states,rotated_key_states=apply_rotary_pos_emb(query_states,key_states,cos,sin,position_ids)

ifpast_key_valueisnotNone:
#reusev,self_attention
value_states=torch.cat([past_key_value[1],value_states],dim=2)

past_key_value=(key_states,value_states)ifuse_cacheelseNone#cachethekeyw/oRoPE

#repeatk/vheadsifn_kv_heads1:
attn_output=attn_output.split(self.hidden_size//self.pretraining_tp,dim=2)
o_proj_slices=self.o_proj.weight.split(self.hidden_size//self.pretraining_tp,dim=1)
attn_output=sum([F.linear(attn_output[i],o_proj_slices[i])foriinrange(self.pretraining_tp)])
else:
attn_output=self.o_proj(attn_output)

ifnotoutput_attentions:


attn_weights=None

returnattn_output,attn_weights,past_key_value


defreplace_llama_attn_with_consistent_ntk_rope():
LlamaAttention.forward=forward





審核編輯:劉清

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報(bào)投訴
  • 解碼器
    +關(guān)注

    關(guān)注

    9

    文章

    1107

    瀏覽量

    40440
  • LLM
    LLM
    +關(guān)注

    關(guān)注

    0

    文章

    247

    瀏覽量

    279

原文標(biāo)題:淺談LLM的長度外推

文章出處:【微信號:zenRRan,微信公眾號:深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    對比解碼在LLM上的應(yīng)用

    為了改進(jìn)LLM的推理能力,University of California聯(lián)合Meta AI實(shí)驗(yàn)室提出將Contrastive Decoding應(yīng)用于多種任務(wù)的LLM方法。實(shí)驗(yàn)表明,所提方法能有效改進(jìn)LLM的推理能力。讓我們走進(jìn)
    發(fā)表于 09-21 11:37 ?502次閱讀
    對比解碼在<b class='flag-5'>LLM</b>上的應(yīng)用

    餓了么確認(rèn)收購百度外賣!最快本周收購,百度外賣為何會(huì)變成百度的棄子?

     餓了么和百度外度一直就以競爭對手的形式出現(xiàn),然而兩者之間相斗爭總有敗的一方,近日餓了么和百度外賣又一次上了熱搜榜。
    發(fā)表于 08-21 14:52 ?827次閱讀

    餓了么正式宣布收購百度外賣 后者人員架構(gòu)不變以獨(dú)立品牌運(yùn)營

    8月24日下午消息,餓了么剛剛正式宣布收購百度外賣。合并完成后,百度外賣成為餓了么的全資子公司。百度外賣仍以獨(dú)立的品牌和運(yùn)營體系發(fā)展,包括管理層在內(nèi)的人員架構(gòu)保持不變。
    發(fā)表于 08-24 16:53 ?731次閱讀

    餓了么正式宣布收購百度外賣 內(nèi)部郵件曝光

    8月24日下午消息,餓了么剛剛正式宣布收購百度外賣,隨后,百度外賣內(nèi)部郵件曝光。郵件表示,合并后,百度外賣仍以獨(dú)立的品牌和運(yùn)營體系發(fā)展,包括管理層在內(nèi)的人員架構(gòu)保持不變。
    發(fā)表于 08-24 16:59 ?794次閱讀

    LLM性能的主要因素

    現(xiàn)在是2023年5月,截止目前,網(wǎng)絡(luò)上已經(jīng)開源了眾多的LLM,如何用較低的成本,判斷LLM的基礎(chǔ)性能,選到適合自己任務(wù)的LLM,成為一個(gè)關(guān)鍵。 本文會(huì)涉及以下幾個(gè)問題: 影響LLM性能
    的頭像 發(fā)表于 05-22 15:26 ?1495次閱讀
    <b class='flag-5'>LLM</b>性能的主要因素

    中國研究人員提出StructGPT,提高LLM對結(jié)構(gòu)化數(shù)據(jù)的零樣本推理能力

    盡管結(jié)構(gòu)化數(shù)據(jù)的體量往往非常巨大,但不可能容納輸入提示中的所有數(shù)據(jù)記錄(例如,ChatGPT 的最大上下文長度為 4096)。將結(jié)構(gòu)化數(shù)據(jù)線性化為 LLM 可以輕松掌握的語句是解決此問題的簡單方法。工具操作技術(shù)激勵(lì)他們增強(qiáng) LLM
    的頭像 發(fā)表于 05-24 16:02 ?2740次閱讀
    中國研究人員提出StructGPT,提高<b class='flag-5'>LLM</b>對結(jié)構(gòu)化數(shù)據(jù)的零樣本推理能力

    使用MLC-LLM支持RWKV-5推理的過程思考

    LLM的理解比較有限,從代碼實(shí)現(xiàn)的角度來說,RWKV的狀態(tài)和KV Cache不同,不依賴序列長度,這讓RWKV模型在各種長度下運(yùn)行內(nèi)存和運(yùn)行速度都是趨于穩(wěn)定的,所以我感覺工程價(jià)值是比基于Transformer架構(gòu)比如Llama
    的頭像 發(fā)表于 11-19 15:58 ?854次閱讀
    使用MLC-<b class='flag-5'>LLM</b>支持RWKV-5推理的過程思考

    如何利用位置編碼實(shí)現(xiàn)長度外

    無論是縮放位置索引還是修改基地,所有token都變得彼此更接近,這將損害LLM區(qū)分相近token的位置順序的能力。結(jié)合他們對RoPE的波長的觀察,存在一些波長比預(yù)訓(xùn)練的上下文窗口長的維度,NTK-by-parts插值的作者建議完全不插值較高的頻率維度。
    發(fā)表于 01-08 09:58 ?400次閱讀
    如何利用位置編碼實(shí)現(xiàn)<b class='flag-5'>長度外</b><b class='flag-5'>推</b>?

    LLM推理加速新范式!推測解碼(Speculative Decoding)最新綜述

    低下(->每個(gè)token的生成都需要重復(fù)讀寫LLM的巨量參數(shù)),并且序列的生成時(shí)間隨著序列長度的增加而線性增加。
    的頭像 發(fā)表于 01-29 15:54 ?1929次閱讀
    <b class='flag-5'>LLM</b>推理加速新范式!推測解碼(Speculative Decoding)最新綜述

    100%在樹莓派上執(zhí)行的LLM項(xiàng)目

    ChatGPT的人性口語化回復(fù)相信許多人已體驗(yàn)過,也因此掀起一波大型語言模型(Large Language Model, LLM)熱潮,LLM即ChatGPT背后的主運(yùn)作技術(shù),但LLM運(yùn)作需要龐大運(yùn)算力,因此目前多是在云端(Cl
    的頭像 發(fā)表于 02-29 16:29 ?1114次閱讀
    100%在樹莓派上執(zhí)行的<b class='flag-5'>LLM</b>項(xiàng)目

    hdmi線纜長度根據(jù)什么決定選擇

    可以達(dá)到30米,這足以支持一個(gè)1080p的視頻和一個(gè)8聲道的音頻信號。 然而,需要注意的是,對于4K分辨率的HDMI線纜,其長度應(yīng)小于15米,以確保最佳的圖像和音頻質(zhì)量。 在選擇HDMI線纜時(shí),除了考慮長度外,還需要注意線纜的規(guī)格和品質(zhì)。HDMI線纜有多種規(guī)格,包括HDM
    的頭像 發(fā)表于 06-06 11:44 ?1175次閱讀

    什么是LLM?LLM的工作原理和結(jié)構(gòu)

    隨著人工智能技術(shù)的飛速發(fā)展,大型語言模型(Large Language Model,簡稱LLM)逐漸成為自然語言處理(NLP)領(lǐng)域的研究熱點(diǎn)。LLM以其強(qiáng)大的文本生成、理解和推理能力,在文本
    的頭像 發(fā)表于 07-02 11:45 ?4298次閱讀

    LLM模型的應(yīng)用領(lǐng)域

    在本文中,我們將深入探討LLM(Large Language Model,大型語言模型)的應(yīng)用領(lǐng)域。LLM是一種基于深度學(xué)習(xí)的人工智能技術(shù),它能夠理解和生成自然語言文本。近年來,隨著計(jì)算能力的提高
    的頭像 發(fā)表于 07-09 09:52 ?293次閱讀

    llm模型和chatGPT的區(qū)別

    LLM(Large Language Model)是指大型語言模型,它們是一類使用深度學(xué)習(xí)技術(shù)構(gòu)建的自然語言處理(NLP)模型。LLM模型可以處理各種語言任務(wù),如文本生成、文本分類、機(jī)器翻譯等。目前
    的頭像 發(fā)表于 07-09 09:55 ?457次閱讀

    llm模型有哪些格式

    LLM(Large Language Model,大型語言模型)是一種深度學(xué)習(xí)模型,主要用于處理自然語言處理(NLP)任務(wù)。LLM模型的格式多種多樣,以下是一些常見的LLM模型格式
    的頭像 發(fā)表于 07-09 09:59 ?332次閱讀