In this example we build up and walk through the source code of the Llama model module by module, from scratch, on top of the transformers library.
We then train it to perform a fun little task: adding two numbers.
The input and output look like this:
Input: "12345+54321="
Output: "66666"
We treat this task as a text-generation problem: the input is the first part of a sequence and the output is the rest of it.
This matches the input/output structure of text generation, so Llama can be used for it.
Most open-source LLMs today are built on the transformers library, and their architectures are largely similar to Llama's.
As the saying goes, the devil is in the details. Digging into the source-code details of the Llama model will help you understand the fundamentals behind open-source LLMs (such as rotary position encoding and length extrapolation) and get familiar with how the various parameters are configured and used (such as past_key_value, attention_mask, and so on).
I. Preparing the data
import random

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# define the vocabulary
# special tokens: <PAD> for padding, <BOS> for sequence start, <EOS> for sequence end (ids 0, 1, 2)
words = '<PAD>,<BOS>,<EOS>,1,2,3,4,5,6,7,8,9,0,+,='
vocab = {word: i for i, word in enumerate(words.split(','))}
vocab_r = [k for k, v in vocab.items()]  # reverse lookup list
# dataset for adding two numbers
def get_data(min_length=10, max_length=20):
    # the digit vocabulary
    words = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # sampling probability of each digit
    p = np.array([7, 5, 5, 7, 6, 5, 7, 6, 5, 7])
    p = p / p.sum()

    # randomly sample n1 digits as s1
    n1 = random.randint(min_length, max_length)
    s1 = np.random.choice(words, size=n1, replace=True, p=p)
    s1 = s1.tolist()

    # randomly sample n2 digits as s2
    n2 = random.randint(min_length, max_length)
    s2 = np.random.choice(words, size=n2, replace=True, p=p)
    s2 = s2.tolist()

    # x is the character-level expression built from s1 and s2
    x = s1 + ['+'] + s2 + ['=']

    # y is the numerical sum of s1 and s2
    y = int(''.join(s1)) + int(''.join(s2))
    y = list(str(y))

    # add the start and end tokens
    x = ['<BOS>'] + x
    y = y + ['<EOS>']

    return x, y

x, y = get_data()
print(''.join(x) + ''.join(y))
3914835626735057733+318829464988=3914835945564522721
# define the dataset
class TwoSumDataset(torch.utils.data.Dataset):
    def __init__(self, size=100000, min_length=10, max_length=20):
        super(Dataset, self).__init__()
        self.size = size
        self.min_length = min_length
        self.max_length = max_length

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        x, y = self.get(i)

        # encode into tokens
        context_ids = [vocab[i] for i in x]
        target_ids = [vocab[i] for i in y]

        input_ids = context_ids + target_ids

        # positions labeled -100 are ignored by the loss, so only the target part contributes to the loss
        labels = [-100] * len(context_ids) + target_ids
        masks = [0 if t == vocab['<PAD>'] else 1 for t in input_ids]

        example = {'input_ids': input_ids,
                   'labels': labels, 'attention_mask': masks}

        return example

    def get(self, i):
        return get_data(self.min_length, self.max_length)

    def show_example(self, example):
        input_ids, labels = example['input_ids'], example['labels']
        x = ''.join([vocab_r[a] for a, b in zip(input_ids, labels) if b == -100])
        y = ''.join([vocab_r[a] for a, b in zip(input_ids, labels) if b != -100])
        print(x + y)


ds_train = TwoSumDataset(size=100000, min_length=10, max_length=20)
ds_val = TwoSumDataset(size=10000, min_length=10, max_length=20)

example = ds_train[0]
ds_train.show_example(example)
12878683929048906366+11274414130675477=12889958343179581843
def data_collator(examples: list):
    len_ids = [len(example["input_ids"]) for example in examples]
    longest = max(len_ids)  # pad to the longest input_ids in the batch

    input_ids = []
    labels_list = []
    masks_list = []

    for length, example in sorted(zip(len_ids, examples), key=lambda x: -x[0]):
        ids = example["input_ids"]
        labs = example["labels"]
        masks = example['attention_mask']

        ids = [vocab['<PAD>']] * (longest - length) + ids
        labs = [-100] * (longest - length) + labs
        masks = [0] * (longest - length) + masks

        input_ids.append(torch.LongTensor(ids))
        labels_list.append(torch.LongTensor(labs))
        masks_list.append(torch.LongTensor(masks))

    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels_list)
    attention_mask = torch.stack(masks_list)
    return {
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask
    }

# data loaders
dl_train = DataLoader(dataset=ds_train,
                      batch_size=200,
                      drop_last=True,
                      shuffle=True,
                      collate_fn=data_collator
                      )

dl_val = DataLoader(dataset=ds_val,
                    batch_size=200,
                    drop_last=True,
                    shuffle=False,
                    collate_fn=data_collator
                    )
for batch in dl_train:
    break

batch
{'input_ids': tensor([[ 1, 11, 6, ..., 7, 11, 2], [ 0, 1, 6, ..., 5, 4, 2], [ 0, 1, 7, ..., 8, 8, 2], ..., [ 0, 0, 0, ..., 10, 11, 2], [ 0, 0, 0, ..., 12, 3, 2], [ 0, 0, 0, ..., 11, 12, 2]]), 'labels': tensor([[-100, -100, -100, ..., 7, 11, 2], [-100, -100, -100, ..., 5, 4, 2], [-100, -100, -100, ..., 8, 8, 2], ..., [-100, -100, -100, ..., 10, 11, 2], [-100, -100, -100, ..., 12, 3, 2], [-100, -100, -100, ..., 11, 12, 2]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1], [0, 1, 1, ..., 1, 1, 1], [0, 1, 1, ..., 1, 1, 1], ..., [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1], [0, 0, 0, ..., 1, 1, 1]])}
?
?
II. Defining the model
Below, we will build the LLaMA model from the bottom up, like assembling a castle out of building blocks.
First we build four basic components: rotary position encoding, multi-head attention, the feed-forward network, and layer normalization. This is like using the most basic blocks to build walls, a roof, doors and windows.
Then we use these four basic components to build an intermediate product: the decoder layer. This is like using the basic components to assemble rooms.
Next we stack multiple decoder layers to assemble the complete LlamaModel, which is like building the main structure of the castle out of many rooms.
Finally, on top of LlamaModel we design two different output heads. One is a language-model head, giving LlamaForCausalLM, which can be used for text generation.
The other is a classification head, giving LlamaForSequenceClassification, which can be used for text classification.
This is like choosing two different fit-outs once the main structure of the castle is finished: one adds amusement facilities for commercial use, the other adds weapons for military use.
1. Rotary position encoding: RoPE (an absolute position encoding implemented with rotation matrices that achieves the effect of relative position encoding)
2. Multi-head attention: LlamaAttention (fuses information across different tokens)
3. Feed-forward network: LlamaMLP (applies a position-wise high-dimensional transformation to the information fused by multi-head attention)
4. Layer normalization: LlamaRMSNorm (stabilizes the inputs; roughly speaking, it keeps the direction of each token vector unchanged while normalizing its norm)
5. Llama decoder layer: LlamaDecoderLayer (the basic structural unit that both fuses and transforms information)
6. Llama decoder: LlamaModel (a stack of decoder layers)
7. Llama language model: LlamaForCausalLM (the decoder plus a language-model head, usable for text generation)
8. Llama classification model: LlamaForSequenceClassification (the decoder plus a classification head, usable for text classification)
?
?
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, LLAMA_START_DOCSTRING

logger = logging.get_logger('llama')

config = LlamaConfig(
    vocab_size=len(vocab),
    hidden_size=512,
    intermediate_size=2752,
    num_hidden_layers=8,
    num_attention_heads=16,
    hidden_act='silu',
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    tie_word_embeddings=False
)
?
?
1. Rotary position encoding: RoPE
Rotary Position Embedding, abbreviated RoPE, represents the position encoding with rotation matrices.
Three key points about RoPE:
The design idea of RoPE is to use an absolute position encoding to achieve the effect of a relative position encoding.
RoPE is implemented by representing the absolute position encoding with rotation matrices.
With the NTK extension, RoPE can be trained on short texts and still make predictions on long texts.
References:
《博采眾長(zhǎng)的旋轉(zhuǎn)式位置編碼》(Rotary position encoding that takes the best of both worlds): https://kexue.fm/archives/8265
《RoPE是一種β進(jìn)制編碼》(RoPE is a kind of β-base encoding): https://kexue.fm/archives/9675
(1) Absolute and relative position encoding
Position encodings generally fall into two categories: absolute and relative.
Absolute position encoding is simple and efficient to compute, but usually performs somewhat worse than relative position encoding.
Relative position encoding performs better, but is less efficient to compute.
Absolute position encoding: position information is derived from each token's absolute position $m$ alone, for example by adding a (learned or sinusoidal) position vector $p_m$ to the token embedding, so the attention weight depends on both absolute positions $m$ and $n$.
Relative position encoding: position information is modeled so that the attention weight can be written in a form that depends only on the relative position, $a_{m,n} = g(x_m, x_n, m-n)$.
With relative position encoding, the attention weight depends only on the relative position of the token vectors involved, not directly on their absolute positions.
This matches the translation invariance of NLP sequences along the length dimension, which is why relative position encoding usually performs better than absolute position encoding.
Absolute position encoding is not without merit, though: it only needs to assign a position encoding to each position once at initialization (a number proportional to the sequence length), with no further intervention afterwards.
Relative position encoding, on the other hand, has to look up many relative positions during the computation (a number proportional to the square of the sequence length).
So absolute position encoding is simpler and more efficient.
(2) Representing the position encoding with rotation matrices
From the discussion above, absolute and relative position encodings each have their pros and cons. Is there a way to get the best of both?
Yes: RoPE. Its design idea is to use an absolute position encoding to achieve the effect of a relative position encoding.
So how does rotary position encoding achieve a relative effect with an absolute encoding? The answer is to represent the position encoding with rotation matrices.
Concretely, the query at position $m$ and the key at position $n$ are each multiplied by a rotation matrix, $\tilde{q}_m = R_m q$ and $\tilde{k}_n = R_n k$, where $R_m$ is a rotation matrix satisfying $R_m^{\top} R_n = R_{n-m}$. It follows that
$$\tilde{q}_m^{\top}\tilde{k}_n = (R_m q)^{\top}(R_n k) = q^{\top} R_m^{\top} R_n k = q^{\top} R_{n-m} k,$$
so the attention score depends only on the relative position $n-m$, which is exactly the form of a relative position encoding.
Perfect! We obtained the effect of a relative position encoding using an absolute position encoding.
So what does the rotation matrix look like?
In the two-dimensional case it looks like this:
$$R_m = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}$$
In NLP, token embeddings are usually high-dimensional (for example 4096).
Using a block-matrix argument, one can show that the following high-dimensional, block-diagonal extension still satisfies the rotation-matrix property:
$$R_m = \mathrm{diag}\big(R_m^{(1)}, R_m^{(2)}, \cdots, R_m^{(d/2)}\big), \qquad R_m^{(i)} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}$$
where $\theta_i = 10000^{-2(i-1)/d},\ i = 1, 2, \cdots, d/2$. In other words, higher dimensions get smaller angular-frequency coefficients: they are lower-frequency, have longer periods, and change more slowly.
Since the rotation matrix is sparse, multiplying by it directly would waste compute; the rotary position encoding can instead be computed as the sum of two element-wise (Hadamard) products.
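Expanding the block-diagonal multiplication above gives the following equivalent element-wise form (this is the standard interleaved pairing from the RoPE paper; the transformers implementation instead pairs the first half of the dimensions with the second half, which is the same computation up to a fixed permutation, as noted in the comments of the rotate_half function further below):
$$R_m x = \begin{pmatrix} x_1 \\ x_2 \\ x_3 \\ x_4 \\ \vdots \\ x_{d-1} \\ x_d \end{pmatrix} \odot \begin{pmatrix} \cos m\theta_1 \\ \cos m\theta_1 \\ \cos m\theta_2 \\ \cos m\theta_2 \\ \vdots \\ \cos m\theta_{d/2} \\ \cos m\theta_{d/2} \end{pmatrix} + \begin{pmatrix} -x_2 \\ x_1 \\ -x_4 \\ x_3 \\ \vdots \\ -x_d \\ x_{d-1} \end{pmatrix} \odot \begin{pmatrix} \sin m\theta_1 \\ \sin m\theta_1 \\ \sin m\theta_2 \\ \sin m\theta_2 \\ \vdots \\ \sin m\theta_{d/2} \\ \sin m\theta_{d/2} \end{pmatrix}$$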
(3) Length extension for rotary position encoding
A very important parameter in LLM applications is the maximum context length the model supports.
A longer context allows more rounds of dialogue, lets us summarize and analyze longer texts, and lets us generate longer articles.
But the training corpus is mostly short; many LLMs are trained with a maximum text length of only 2k, i.e. at most 2048 tokens.
So, can we train on shorter texts and extend to longer texts at inference time?
It is possible: we can apply length extension to RoPE.
We introduce three extension schemes.
The first is direct extrapolation: simply keep using the existing position-encoding formula without any modification.
When the extension is small, say from 2k to 2.5k, this may not hurt performance much.
That is because rotary position encoding depends only on the relative distance m-n and generally exhibits long-range decay: the larger the relative distance between two tokens, the weaker their interaction.
If the model has already learned from the training data a suitable decay pattern of token interactions over relative distances in the 0-2k range, it is plausible that applying that pattern to 0-2.5k will not cause much trouble.
But if we extend to a much longer length, say from 2k to 32k, direct extrapolation usually hurts performance badly: the learned decay pattern may have essentially decayed to 0 by, say, 5k, so interactions between tokens more than 5k apart cannot be captured, and extrapolation degrades performance.
In short, direct extrapolation tends to misuse the decay pattern at long distances, which degrades performance.
To reduce the impact of length extrapolation on performance, we can fine-tune the trained model for a few steps on longer contexts.
The second scheme is linear interpolation: it changes the position-encoding formula, which amounts to shrinking the position indices proportionally.
The formula changes so that position $m$ becomes $m/k$, where $k$ is the extension factor; when extending from 2k to 32k, the position indices effectively need to be scaled down to 1/16 of their original values.
Linear interpolation does not change the range over which the learned decay pattern is applied, so without fine-tuning it generally works better than direct extrapolation.
However, when the extension factor is very large, say from 2k to 32k, performance still suffers noticeably.
In that case the decay pattern is misused at short distances: two tokens that used to be at distance 1 effectively end up at distance 1/16 after the extension, and since the decay pattern can change very rapidly at short distances, the estimated interaction can deviate wildly from a reasonable value.
With linear interpolation, fine-tuning for a few steps on long texts also improves performance markedly.
The third scheme is the NTK extension, which combines the advantages of extrapolation and interpolation and maintains good performance after length extension even without fine-tuning.
From the analysis above, direct extrapolation tends to misuse the decay pattern at long distances but leaves short distances unaffected.
Linear interpolation, by contrast, tends to misuse the decay pattern at short distances but has little effect at long distances.
Can we combine them so that the encoding extrapolates at short distances (staying essentially the same as before the extension) and interpolates at long distances (scaling back into the pre-extension range), so that the decay pattern is misused neither at short nor at long distances?
Looking at the formula for the $i$-th component of RoPE, $\theta_i = 10000^{-2(i-1)/d}$, we see that the larger $i$ is, the smaller the angular-frequency coefficient, i.e. the lower the frequency and the slower the corresponding trigonometric functions change.
This gives an intuitive conclusion: differences over short distances (say between 1 and 5) show up mainly in the high-frequency components (small $i$), while differences over long distances (say between 5000 and 10000) show up mainly in the low-frequency components (large $i$).
To extrapolate at short distances and interpolate at long distances, we can therefore design a position-index scaling factor that depends on $i$: it should equal 1 at the highest frequency ($i=1$, essentially unchanged from before the extension), and equal the reciprocal of the extension factor $k$ at the lowest frequency ($i=d/2$, scaling back into the pre-extension range).
An effective choice is an exponential function of $i$, $\lambda(i) = k^{-2(i-1)/(d-2)}$, which is equivalent to rescaling the base 10000 inside $\theta_i$; from the boundary conditions it is easy to work out that the base should be multiplied by $k^{d/(d-2)}$.
The gist of the NTK extension is: extrapolate the high frequencies, interpolate the low frequencies. It is implemented by directly rescaling the base, much like converting between numeral bases.
With the NTK extension to long texts, performance drops only slightly even without fine-tuning.
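As a quick sanity check of the boundary conditions, here is a small illustrative snippet (not part of the original code; dim, base and k are assumed toy values) that compares the per-component angular frequencies before and after NTK scaling of the base. The highest-frequency component stays unchanged (extrapolation) while the lowest-frequency one is scaled by exactly 1/k (interpolation):
import numpy as np

dim, base, k = 8, 10000.0, 16.0           # k is the assumed length-extension factor
new_base = base * k ** (dim / (dim - 2))  # NTK: rescale the base

theta = base ** (-np.arange(0, dim, 2) / dim)          # original angular frequencies
theta_ntk = new_base ** (-np.arange(0, dim, 2) / dim)  # NTK-scaled angular frequencies

print(theta_ntk / theta)  # first ratio: 1.0 (unchanged), last ratio: 1/k = 0.0625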
Below is the implementation of RoPE and the three length-extension schemes.
?
?
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # persistent=False: not saved in the state_dict

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]

        # if seq_len exceeds the preset max_position_embeddings, recompute a larger RoPE cache; otherwise just slice the cache
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor  # linear interpolation: shrink the position indices proportionally

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))  # the NTK extension rescales the base directly
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # this layout differs from the original RoPE; the original interleaved logic was:
        # emb = torch.cat((freqs, freqs), dim=-1)
        # emb[..., 0::2] = freqs
        # emb[..., 1::2] = freqs

        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""

    # this differs from the original RoPE; the original interleaved logic was:
    # x1 = x[..., 0::2]
    # x2 = x[..., 1::2]
    # res = torch.cat((x1, x2), dim=-1)
    # res[..., 0::2] = -x2
    # res[..., 1::2] = x1
    # return res

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
x = torch.randn(1, 8, 4, 2)
rope = LlamaRotaryEmbedding(dim=8)
cos, sin = rope.forward(x, seq_len=4)
print(cos.shape)
print(cos)
torch.Size([1, 1, 4, 8]) tensor([[[[ 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [ 0.5403, 0.9950, 0.9999, 1.0000, 0.5403, 0.9950, 0.9999, 1.0000], [-0.4161, 0.9801, 0.9998, 1.0000, -0.4161, 0.9801, 0.9998, 1.0000], [-0.9900, 0.9553, 0.9996, 1.0000, -0.9900, 0.9553, 0.9996, 1.0000]]]])
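A small numerical check (not part of the original code): if we place the same q and the same k at every position and rotate them, the inner product between a rotated query and a rotated key depends only on their relative position, as the RoPE derivation promises:
q = torch.randn(8).repeat(1, 1, 8, 1)   # [bs=1, head=1, seq_len=8, head_dim=8]: the same q at every position
k = torch.randn(8).repeat(1, 1, 8, 1)   # the same k at every position
cos, sin = rope(q, seq_len=8)
position_ids = torch.arange(8).unsqueeze(0)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(torch.dot(q_rot[0, 0, 2], k_rot[0, 0, 0]))  # relative position 2
print(torch.dot(q_rot[0, 0, 5], k_rot[0, 0, 3]))  # relative position 2 as well: approximately the same value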
?
?
2. Multi-head attention: LlamaAttention
LlamaAttention is essentially the attention from the "Attention Is All You Need" paper, with the following main differences.
1. The number of key/value heads can be a fraction of the number of query heads (similar in spirit to grouped convolutions), which reduces the parameter count.
2. The RoPE position encoding is applied every time multi-head attention is computed (i.e. in every layer), rather than only once at the input as in the original paper.
3. A cache of past key and value states (past_key_value) can be passed in, which avoids recomputation in multi-turn dialogue and speeds things up.
4. The attention_mask is applied additively to the attention score matrix right before the softmax (see the small illustration after this list).
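Here is a small illustration of point 4 (not part of the original code): masked positions get a very large negative number added to their scores, so after the softmax their attention weights are essentially zero.
scores = torch.randn(1, 1, 4, 4)     # [bs, num_heads, q_len, k_len]
mask = torch.tensor([[0.0, 0.0, 0.0, torch.finfo(torch.float32).min]])  # mask out the last key position
probs = torch.softmax(scores + mask, dim=-1)
print(probs[0, 0])                   # the last column of every row is (almost) zero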
?
?
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # the attention_mask is added to the scores before the softmax
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
?
?
3. Feed-forward network: LlamaMLP
The feed-forward network is a two-layer MLP.
It first projects from hidden_size up to intermediate_size (up_proj), then projects back down to hidden_size (down_proj).
Its main special feature is the extra gate_proj, which together with the activation function implements a gating mechanism (a gated linear unit), as summarized in the formula below.
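In formula form (with $\odot$ denoting element-wise multiplication and SiLU as the activation), matching the forward pass in the code below:
$$\mathrm{LlamaMLP}(x) = W_{down}\Big(\mathrm{SiLU}(W_{gate}\,x)\odot (W_{up}\,x)\Big)$$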
?
?
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj
?
?
4. Layer normalization: LlamaRMSNorm
The layer normalization used here is RMSNorm, which differs slightly from standard LayerNorm.
First, it does not subtract the mean; it divides directly by the root mean square. Second, it does not add a bias.
These two small changes ensure that the normalization does not change the direction of the token vector in hidden_states, only its norm.
This makes a certain amount of sense.
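In formula form (with $d$ the hidden size, $w$ the learnable per-dimension scale, and $\epsilon$ = rms_norm_eps), matching the code below:
$$\mathrm{RMSNorm}(x) = w \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$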
?
?
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
?
?
5. The Llama decoder layer
The decoder layer LlamaDecoderLayer consists of a LlamaAttention, a LlamaMLP, and two LlamaRMSNorm modules, with two residual connections.
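In formula form, one decoder layer computes (matching the forward method below):
$$h = x + \mathrm{Attention}\big(\mathrm{RMSNorm}(x)\big), \qquad \mathrm{out} = h + \mathrm{MLP}\big(\mathrm{RMSNorm}(h)\big)$$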
?
?
class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
?
?
6. The Llama decoder
LlamaModel is a stack of Llama decoder layers.
A few points worth understanding:
1. _make_causal_mask builds the lower-triangular mask structure that gives the language model its unidirectional attention.
2. _expand_mask expands the incoming mask for special tokens such as <PAD> into a tensor with the same structure as the attention matrix (see the small illustration after this list).
3. Setting gradient_checkpointing=True saves GPU memory. It mainly relies on torch.utils.checkpoint.checkpoint, and the idea is simple: intermediate activations are not saved during the forward pass of each decoder_layer and are recomputed during the backward pass, trading time for memory.
4. gradient_checkpointing and use_cache cannot both be set to True: the former saves memory by trading time for space, while the latter saves time by trading space for time.
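As a small illustration of points 1 and 2 (not part of the original code), once the two helper functions defined in the code below are available, you can inspect the masks they produce:
causal = _make_causal_mask((1, 4), torch.float32, device=torch.device('cpu'))
print(causal[0, 0])   # upper triangle is a huge negative number, lower triangle (incl. diagonal) is 0: unidirectional attention

pad = _expand_mask(torch.tensor([[0, 1, 1, 1]]), torch.float32)
print(pad[0, 0])      # the padded position (first column) becomes a huge negative number, the rest stay 0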
?
?
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype,
    device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
?
?
7. The Llama language model
The Llama language model LlamaForCausalLM adds an lm_head on top of the Llama decoder LlamaModel as the generator,
which turns it into a complete language model.
Beyond that, LlamaForCausalLM implements the following important functionality.
1. Loss computation: when labels is passed to the forward method, the language-model cross-entropy loss is computed automatically; positions labeled -100 are ignored and do not contribute to the loss.
2. The generate method for text generation, inherited from PreTrainedModel: model.generation_config.num_beams sets the beam width of beam search, with a default of 1, i.e. greedy search.
?
?
_CONFIG_FOR_DOC = "LlamaConfig"

class LlamaForCausalLM(LlamaPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100. Tokens with indices set to `-100` are ignored (masked); the loss is only computed for tokens with
            labels in `[0, ..., config.vocab_size]`.

        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past
?
8. The Llama classification model
LlamaForSequenceClassification is a sequence classification model.
This classification model can be used, for example, to train the reward model in an RLHF pipeline.
?
?
@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
                    logits.device
                )
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
?
III. Training the model
Now let's train a LlamaForCausalLM to perform the two-number addition task.
?
?
config = LlamaConfig(
    vocab_size=len(vocab),
    hidden_size=512,
    intermediate_size=2752,
    num_hidden_layers=8,
    num_attention_heads=16,
    num_key_value_heads=4,
    rope_scaling=None,
    hidden_act='silu',
    max_position_embeddings=128,
    initializer_range=0.02,
    rms_norm_eps=1e-06,
    use_cache=True,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    tie_word_embeddings=False,
    pretraining_tp=1,
    max_new_tokens=100
)

# a quick test run
model = LlamaForCausalLM(config)
out = model.forward(**batch)
print(out.loss)
?
tensor(2.7630, grad_fn=)
?
?
from torchkeras import KerasModel
from accelerate import Accelerator

class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage="train", metrics_dict=None,
                 optimizer=None, lr_scheduler=None
                 ):
        self.net, self.loss_fn, self.metrics_dict, self.stage = net, loss_fn, metrics_dict, stage
        self.optimizer, self.lr_scheduler = optimizer, lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        if self.stage == 'train':
            self.net.train()
        else:
            self.net.eval()

    def __call__(self, batch):

        # loss
        with self.accelerator.autocast():
            loss = self.net(**batch).loss

        # backward()
        if self.stage == "train" and self.optimizer is not None:
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()

        all_loss = self.accelerator.gather(loss).sum()

        # losses (or plain metrics that can be averaged)
        step_losses = {self.stage + "_loss": all_loss.item()}

        # metrics (stateful metrics)
        step_metrics = {}

        if self.stage == "train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses, step_metrics

KerasModel.StepRunner = StepRunner

keras_model = KerasModel(model, loss_fn=None,
                         optimizer=torch.optim.AdamW(model.parameters(), lr=3e-5))

# checkpoint path (previously trained weights can be reloaded from here)
ckpt_path = 'llama_twosum'

keras_model.fit(train_data=dl_train,
                val_data=dl_val,
                epochs=100, patience=5,
                monitor='val_loss', mode='min',
                ckpt_path=ckpt_path,
                mixed_precision='fp16'
                )
?
IV. Using the model
?
?
from transformers.generation.utils import GenerationConfig

model.generation_config = GenerationConfig.from_dict({'num_beams': 1,
                                                      'max_new_tokens': 100,
                                                      'max_length': 200})
model.generation_config.num_beams = 1
model.generation_config.max_new_tokens = 100
model.generation_config.max_length = 200

def get_ans(tensor) -> "str":
    s = "".join([vocab_r[i] for i in tensor.tolist()])
    ans = s[s.find('=') + 1:s.find('<EOS>')].replace('<EOS>', '').replace('<PAD>', '')
    return ans

x, y = get_data()
print('x: ' + ''.join(x).replace('<BOS>', ''))
print('y: ' + ''.join(y).replace('<EOS>', ''))

x: 3481340050+90157504501803=
y: 90160985841853

input_ids = torch.tensor([[vocab[i] for i in x]])
out = model.generate(inputs=input_ids)
out
?
tensor([[ 1, ?5, ?6, 10, ?3, ?5, ?6, 12, 12, ?7, 12, 13, 11, 12, ?3, ?7, ?9, ?7,12, ?6, ?7, 12, ?3, 10, 12, ?5, 14, 11, 12, ?3, ?8, 12, 11, 10, ?7, 10,6, ?3, 10, ?7, ?5, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2,2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, ?2, 12, ?2, ?2, ?2, ?2, ?2, ?2, ?2,2, 12, ?3, 12, ?3]])
?
?
get_ans(out[0])
?
'90160985841853'
V. Evaluating the model
?
?
from tqdm import tqdm

loop = tqdm(range(1, 201))
correct = 0
for i in loop:
    x, y = get_data()
    input_ids = torch.tensor([[vocab[i] for i in x]])
    out = model.generate(inputs=input_ids)
    pred = get_ans(out[0])
    gt = ''.join(y).replace('<EOS>', '')
    if pred == gt:
        correct += 1
    loop.set_postfix(acc=correct / i)

print("acc=", correct / len(loop))
?
acc= 0.99
Nice — our test accuracy reached 99%!