本篇是《Rust與AI》系列的第二篇,上一篇我們主要介紹了本系列的概覽和方向,定下了一個基調(diào)。本篇我們將介紹LLM的基本架構(gòu),我們會以迄今為止使用最廣泛的開源模型LLaMA為例展開介紹。
LLM背景
Rust 本身是不挑 AI 模型的,但是 LLM 是當(dāng)下最熱的方向,我們就從它開始吧,先了解一些非常基礎(chǔ)的背景知識。
Token
LLM 中非常重要的一個概念是 Token,我們輸入給 LLM 和它輸出的都是 Token。Token 在這里可以看做語言的基本單位,中文一般是詞或字(其實字也是詞)。比如:”我們喜歡 Rust 語言“,Token 化后會變成類似 ”我們/喜歡/Rust/語言“ 這樣的四個詞,可以理解為四個 Token。
給定一段任意的自然語言文本,我們可以用一個分詞器(Tokenizer)將其 Token 化成一個個連續(xù)的 Token。這些 Token 接下來就可以映射成一個個數(shù)字,其實是在詞表中的索引,索引進(jìn)而可以找到一個稠密向量,用來表示該位置 Token 的語義輸入。
我們以剛剛的”我們喜歡 Rust 語言“為例,假定已有詞表如下。
…… 1000 Rust …… 2000 我們 2001 喜歡 2002 語言 ……
注意,前面的數(shù)字是行號,并不是詞表內(nèi)容。剛剛那句話其實就是 [2000, 2001, 1000, 2002],這就是 LLM 的輸入。LLM 拿到這些 ID 后,會在一個非常大的表里查找對應(yīng)的稠密向量。這個非常大的表就是詞表,大小是:詞表大小N × 模型維度,如下所示。
…… 1000 0.9146, 0.066, 0.4469, 0.3867, 0.3221, 0.6566, 0.2895, 。.. …… 2000 0.5702, 0.9579, 0.0992, 0.9667, 0.5013, 0.4752, 0.1397, 。.. 2001 0.2896, 0.7756, 0.6392, 0.4034, 0.3267, 0.9643, 0.4311, 。.. 2002 0.4344, 0.6662, 0.3205, 0.3929, 0.6418, 0.6707, 0.2414, 。.. ……
也就是說,輸入”我們喜歡Rust語言“這句話,我們實際傳遞給模型的其實是一個 4×Dim 的矩陣,這里的 4 一般也叫 Sequence Length。
我們可以暫時把模型看作一個函數(shù) f(x),輸入一個 Sequence Length × Dim 的矩陣,經(jīng)過模型 f(x) 各種運(yùn)算后會輸出 Sequence Length × Vocabulary Size 大小的一個概率分布。有了概率分布就可以采樣一個 Token ID(基于上下文最后一個 Token ID 的分布),這個 ID 也就是給定當(dāng)前上下文(”我們喜歡Rust語言“)時生成的下一個 Token。接下來就是把這個 ID 拼在剛剛的 4 個 ID 后面(輸入變成 5 個 ID),繼續(xù)重復(fù)這個過程。
生成
如上所言,生成過程就是從剛剛的概率分布中 “選擇” 出一個 Token ID 作為下一個 Token ID。選擇的方法可以很簡單,比如直接選擇概率最大的,此時就是 Greedy Search,或 Greedy Decoding。
不過我們平時用到大模型時一般都用的是采樣的方法,也就是基于概率分布進(jìn)行采樣。拋硬幣也是一種采樣,按概率分布(0.5,0.5)進(jìn)行采樣,但假設(shè)正面比較重,概率分布就可能變成了(0.8,0.2)了?;?Vocabulary Size 個概率值進(jìn)行采樣也是類似的,只不過括號里的值就是詞表大小那么多個。
top_p/top_k 采樣是概率值太多了,大部分都是概率很小的 Token,為了避免可能采樣到那些概率很低的 Token(此時生成的結(jié)果可能很不連貫),干脆就只從前面的 Token 里挑。
top_k 就是把 Token 按概率從大到小排序,然后從前 k 個里面選擇(采用)下一個 Token;top_p 也是把 Token 按概率從大到小排序,不過是從累積概率大于 p 的 Token 里選。就是這么簡單。
這里有個小細(xì)節(jié)需要說明,因為選擇了 top_p/k,所以這些備選的 Token 需要重新計算概率,讓它們的概率和為 1(100%)。
開源代表——LLaMA
接下來,我們把重心放在函數(shù) f(x) 上,以最流行的開源 LLM——LLaMA 為例,簡單介紹一下模型的結(jié)構(gòu)和參數(shù)。
結(jié)構(gòu)
LLaMA 的結(jié)構(gòu)相對而言比較簡單,如果我們忽略其中的很多細(xì)節(jié),只考慮推理過程,看起來如下圖所示。
圖中 [] 中的是該位置的張量 shape,B 表示 Batch Size,一般時候都是批量丟給 GPU 計算的,L 就是 Sequence Length,D 就是上面提到的 Dim。這是一個簡化了的架構(gòu)圖,但是足以清晰地表達(dá)模型了。
兩個 Hidden states(以下簡稱 HS),外面(之上和之下)的部分我們前面已經(jīng)提到過了(注意上面部分,[B,L,D] 會先變成 [B,L,VS],然后取最后一個 Token 就得到了 [B,1,VS]),上面的 HS 會傳回到 Block 里面,重復(fù) N 次,N 就是模型的層數(shù)。接下來我們就把重點放在中間這個 Block 里。
每個 Block 包括兩個主要模塊,一個 MHA(Multi-Head Attention)模塊,一個 FFN(Feedforward Network)模塊,每次傳給模塊之前都需要 Normalization,這個叫 Pre-Normalization,一般用來穩(wěn)定訓(xùn)練。另外,每個模塊結(jié)束后會疊加模塊之前的輸入,這個叫殘差連接,一般能加速收斂。
接下來是 MHA 和 FFN,先看 FFN 模塊,它的大概流程如下(@ 表示矩陣/張量乘法)。
z1 = ns @ up_weights z2 = ns @ gate_weights z3 = z1 * silu(z2) z4 = z3 @ down_weights
整體來看是先將網(wǎng)絡(luò)擴(kuò)大再收縮,擴(kuò)大時增加了一個激活處理。silu 函數(shù)大概長這樣:
等價于只激活了一部分參數(shù),這個非線性激活非常重要,可以讓模型學(xué)習(xí)到更豐富的知識和表達(dá)。
再就是 MHA 模塊了,大概流程如下(為了更直觀,去掉了 Batch Size 和 Softmax)。
q = ns @ q_weights # (L, D) @ (D, D) = (L, D) k = ns @ k_weights # (L, D) @ (D, D) = (L, D) v = ns @ v_weights # (L, D) @ (D, D) = (L, D) q = q.reshape(L, NH, HD) k = k.reshape(L, NH, HD) v = v.reshpae(L, NH, HD) attn = q.trans(NH, L, HD) @ k.trans(NH, HD, L) # (NH, L, HD) @ (NH, HD, L) = (NH, L, L) v = attn @ v.trans(NH, L, HD) # (NH, L, L) @ (NH, L, HD) = (NH, L, HD) v = v.reshpe(L, NH*HD) # (L, D)
其中,NH 表示 Attention 的 Head 數(shù),HD 表示 Head 的維度。因為有 NH 個 Head,所以叫 Multi-Head,但其實我們看上面的過程,在實際計算的時候它們是合并一起算的。我們不妨只看一個 Head,如下所示。
q = ns @ hq_weights # (L, D) @ (D, HD) = (L, HD) k = ns @ hk_weights # (L, D) @ (D, HD) = (L, HD) v = ns @ hv_weights # (L, D) @ (D, HD) = (L, HD) attn = q @ k.T # (L, HD) @ (HD, L) = (L, L) v = attn @ v # (L, L) @ (L, HD) = (L, HD)
上面的多個 Head 的 v 就是下面的每個 Head 的 v 拼接起來的。
Multi-Head 是多個注意力頭去執(zhí)行 Attention,其思想是讓每個 Head 去捕獲不同角度/層面的 Attention,這些角度/層面是什么?不是特別清楚(但一定是某種特征),但我們可以通過 Attention 的權(quán)重看出外在 Token 級別的注意力,知道每個注意力 Head,哪些 Token 之間有比較強(qiáng)的連接。
參數(shù)
關(guān)于 f(x) 我們已經(jīng)介紹完了,可以發(fā)現(xiàn)這個函數(shù)其實還是有點復(fù)雜的。接下來,我們看看參數(shù)情況。
對一個一元一次方程(比如 f(x) = ax + b)來說,參數(shù)就兩個:a 和 b,但對于 LLM 來說,參數(shù)就非常多了,目前常用的是 7B、13B、20B 的級別,也就是 70億、130億和 200億的參數(shù)規(guī)模。
在神經(jīng)網(wǎng)絡(luò)中,可以把矩陣乘法看作是多元一次方程組的計算過程,輸入的 Hidden State 維度是 D,就表示未知變量的維度是 D,也就是 D 元一次方程組。
以前面的但 Head Attention 的 q 為例,q_weights 是一個 DxHD 的參數(shù)矩陣,我們把 D 和 HD 設(shè)置的小一點(假設(shè)為4和2),看一個具體的例子。
torch.manual_seed(42) w = nn.Linear(4, 2, bias=False) # D=4, HD=2 hs = torch.rand((3, 4)) # L=3, D=4 q = hs @ w.weight.T “”“ hq_weights = w.weight.T = tensor([[ 0.3823, -0.1096], [ 0.4150, 0.1009], [-0.1171, -0.2434], [ 0.4593, 0.2936]]) hs = tensor([[0.9408, 0.1332, 0.9346, 0.5936], [0.8694, 0.5677, 0.7411, 0.4294], [0.8854, 0.5739, 0.2666, 0.6274]]) q = tensor([[ 0.5781, -0.1428], [ 0.6784, -0.0923], [ 0.8336, 0.0803]]) ”“”
這個例子除了維度小一點,其他邏輯是一樣的。它對應(yīng)這么一個多元方程組。
w11*x11 + w21*x12 + w31*x13 + w41*x14 = y11 w12*x11 + w22*x12 + w32*x13 + w42*x14 = y12 w11*x21 + w21*x22 + w31*x23 + w41*x24 = y21 w12*x21 + w22*x22 + w32*x23 + w42*x24 = y22 w11*x31 + w21*x32 + w31*x33 + w41*x34 = y31 w12*x31 + w22*x32 + w32*x33 + w42*x34 = y32
其中 x 就是 hs,w 就是 hq_weights,寫成數(shù)學(xué)表達(dá)式大概就是下面的這樣。 $$ left[egin{array}{llll} x_{11} & x_{12} & x_{13} & x_{14} x_{21} & x_{22} & x_{23} & x_{24} x_{31} & x_{32} & x_{33} & x_{34} end{array} ight] imesleft[egin{array}{ll} w_{11} & w_{12} w_{21} & w_{22} w_{31} & w_{32} w_{41} & w_{42} end{array} ight]=left[egin{array}{ll} y_{11} & y_{12} y_{21} & y_{22} y_{31} & y_{32} end{array} ight] $$ 對于這樣的一個 Linear 來說,參數(shù)量就是 2×4=8 個?,F(xiàn)在讓我們看看 LLaMA,就按詞表大小=32000,維度=4096來計算。
首先是 Embedding 和 LM Head(就是映射到 32000 個 Token 的那個參數(shù)),它們是一樣的,都是 32000×4096,有時候這兩個地方的參數(shù)也可以設(shè)計成共享的,LM Head 前面也有一個 Normalization,4096 個參數(shù)。
然后是 Block,MHA 的 qkvo 是 4 個 4096×4096 的矩陣,F(xiàn)FN 的 gate、up、down 是 11008×4096 的矩陣,再加上兩個 Normalization, 4096×2 個參數(shù)。每個 Block 參數(shù)量為 4096×(4096×4+11008×3+2)。
這樣得到所有的參數(shù)總和為:32000*4096*2 + 4096 +(4096*(4096*4+11008*3+2))*32 = 6738415616,67億多的樣子,也就是常說的 7B。
Rust與LLaMA
終于來到了 Rust,之所以前面鋪墊那么多,是因為如果我們完全不熟悉模型的基本結(jié)構(gòu)和執(zhí)行過程,這個代碼看起來就會知其然而不知其所以然。當(dāng)然,即便了解了基本結(jié)構(gòu),里面也有一些細(xì)節(jié)需要單獨介紹,不過我們會放在后續(xù)的內(nèi)容。
只看上面的內(nèi)容,我們可以發(fā)現(xiàn) LLM 模型的結(jié)構(gòu)其實不算特別復(fù)雜,而且其中涉及到大量的矩陣運(yùn)算(至少占到 80% 以上)。關(guān)于矩陣運(yùn)算以及相關(guān)的優(yōu)化,我們也會在后面慢慢涉及。
LLaMA 的 Rust 實現(xiàn)有很多個版本,本次選擇的是來自 karpathy/llama2.c: Inference Llama 2 in one file of pure C 的 Rust 實現(xiàn)的版本中的:danielgrittner/llama2-rs: LLaMA2 + Rust,而且我們暫時只會涉及模型基礎(chǔ)結(jié)構(gòu)部分,其中涉及一些特別的細(xì)節(jié)會簡單解釋,不深入展開。
配置
首先是配置,如下所示。
struct Config { dim: usize, // transformer dimension hidden_dim: usize, // for ffn layers n_layers: usize, // number of layers n_heads: usize, // number of query heads head_size: usize, // size of each head (dim / n_heads) n_kv_heads: usize, // number of key/value heads shared_weights: bool, vocab_size: usize, // vocabulary size seq_len: usize, // max. sequence length }
dim 就是上面一直說的 Dim,hidden_dim 僅在 FFN 層,因為 FFN 層需要先擴(kuò)大再縮小。n_heads 和 n_kv_heads 是 Query 的 Head 數(shù)和 KV 的 Head 數(shù),簡單起見可以認(rèn)為它們是相等的。如果我們加載 karpathy 的 15M 的模型,結(jié)果如下。
Config { dim: 288, hidden_dim: 768, n_layers: 6, n_heads: 6, head_size: 48, n_kv_heads: 6, shared_weights: true, vocab_size: 32000, seq_len: 256 }
shared_weights 就是上面提到的 Embedding 和 LM Head 是否共享參數(shù)。
Tokenizer 的功能我們暫且略過,目前只需知道它負(fù)責(zé)將文本轉(zhuǎn)為 ID 列表(encode)以及把 ID 列表轉(zhuǎn)為文本(decode)。
參數(shù)
接下來看模型參數(shù),如下所示。
struct TransformerWeights { // Token Embedding Table token_embedding_table: Vec《f32》, // (vocab_size, dim) // Weights for RMSNorm rms_att_weight: Vec《f32》, // (layer, dim) rms_ffn_weight: Vec《f32》, // (layer, dim) // Weights for matmuls in attn wq: Vec《f32》, // (layer, dim, dim) wk: Vec《f32》, // (layer, dim, dim) wv: Vec《f32》, // (layer, dim, dim) wo: Vec《f32》, // (layer, dim, dim) // Weights for ffn w1: Vec《f32》, // (layer, hidden_dim, dim) w2: Vec《f32》, // (layer, dim, hidden_dim) w3: Vec《f32》, // (layer, hidden_dim, dim) // final RMSNorm rms_final_weights: Vec《f32》, // (dim) // freq_cis for RoPE relatively positional embeddings freq_cis_real: Vec《f32》, // (seq_len, head_size/2) freq_cis_imag: Vec《f32》, // (seq_len, head_size/2) // (optional) classifier weights for the logits, on the last layer wcls: Vec《f32》, // (vocab_size, dim) }
上面的參數(shù)應(yīng)該都比較直觀,我們不太熟悉的應(yīng)該是 freq_ 開頭的兩個參數(shù),它們是和位置編碼有關(guān)的參數(shù),也就是說,我們每次生成一個 Token 時,都需要傳入當(dāng)前位置的位置信息。
位置編碼在 Transformer 中是比較重要的,因為 Self Attention 本質(zhì)上是無序的,而語言的先后順序在有些時候是很重要的,比如 “我喜歡你” 和 “你喜歡我”,“你” 和 “我” 的順序不同,語義也不同。但時候很多語義又不太響影我們解理語義,不妨再仔細(xì)讀一下剛剛這半句話。你看文本順序雖然變了,但你讀起來毫無障礙。這也是為什么會有研究說不要位置編碼語言模型也可以,但效果應(yīng)該是不如加了位置編碼的。
模型創(chuàng)建好后,接下來就是加載參數(shù)和執(zhí)行推理。加載參數(shù)要看模型文件的格式設(shè)計,本項目來自 karpathy 的 C 代碼,模型文件被安排成了 bin 文件,按規(guī)定的格式讀取即可,核心代碼如下。
fn byte_chunk_to_vec《T》(byte_chunk: &[u8], number_elements: usize) -》 Vec《T》 where T: Clone, { unsafe { // 獲取起始位置的原始指針 let data = byte_chunk.as_ptr() as *const T; // 從原始指針創(chuàng)建一個 T 類型的切片,注意number_elements是element的數(shù)量,而不是bytes // 這句是 unsafe 的 let slice_data: &[T] = std::from_raw_parts(data, number_elements); // 將切片轉(zhuǎn)為 Vec,需要 T 可以 Clone slice_data.to_vec() } }
byte_chunk 表示原始的字節(jié)切片,number_elements 表示結(jié)果向量中元素的個數(shù),T 有 Clone 的 Trait 約束,表示 T 必須實現(xiàn)該 Trait,也就是 T 必須能夠使用 Clone 方法。其他解釋已經(jīng)在代碼中給出了注釋,不再贅述。
加載模型就是讀取原始的 bin 文件并指定對應(yīng)的參數(shù)大小,我們以 Token Embedding 參數(shù)為例,如下所示。
let token_embedding_table_size = config.vocab_size * config.dim; // offset.。 表示從 offset 往后的所有元素 let token_embedding_table: Vec《f32》 = byte_chunk_to_vec(&mmap[offset.。], token_embedding_table_size);
類似這樣就可以依次把模型參數(shù)讀取進(jìn)來了。
模型
接下來就是最復(fù)雜的模型部分了。這里最大的不同是 Token by Token 的處理,而不是給定一個上下文生成下一個 Token。我們看一下基本的 Struct,如下所示。
struct LLaMA2《‘a(chǎn)》 { // buffers for current activations x: Vec《f32》, // activation at current timestep (dim,) xb: Vec《f32》, // same, but inside a residual branch (dim,) xb2: Vec《f32》, // additional buffer (dim,) hb: Vec《f32》, // buffer for hidden dimension in the ffn (hidden_dim,) hb2: Vec《f32》, // buffer for hidden dimension in the ffn (hidden_dim,) q: Vec《f32》, // query (dim,) k: Vec《f32》, // key (dim,) v: Vec《f32》, // value (dim,) att: Vec《f32》, // attention scores (n_heads, seq_len) logits: Vec《f32》, // output logits (vocab_size,) // kv cache key_cache: Vec《f32》, // (layer, seq_len, dim) value_cache: Vec《f32》, // (layer, seq_len, dim) // weights & config transformer: &’a TransformerWeights, config: &‘a(chǎn) Config, }
最后兩個參數(shù)我們上面已經(jīng)介紹過了,其他參數(shù)都是模型推理過程中需要用到的中間結(jié)果和最初的輸入,以及最終的結(jié)果,它們均被初始化成 0。至于為什么有些值是多個(比如 xb、hb等),是因為 Block 里面涉及到殘差連接,需要額外保存一個輸入。
現(xiàn)在我們從 forward 開始,方法如下。
fn forward(&mut self, token: usize, pos: usize) { // fetch the token embedding self.x.copy_from_slice( &self.transformer.token_embedding_table [(token * self.config.dim)。.((token + 1) * self.config.dim)], ); // Note: here it always holds that seqlen == 1 in comparison to the PyTorch implementation for l in 0..self.config.n_layers { self.layer(l, pos); } // final RMSNorm rmsnorm( self.x.as_mut_slice(), self.transformer.rms_final_weights.as_slice(), ); // generate logits, i.e., map activations from dim to vocab_size matmul( self.logits.as_mut_slice(), // out: (vocab_size,) self.transformer.wcls.as_slice(), // W: (vocab_size, dim) self.x.as_slice(), // x: (dim,) ); }
這塊代碼是推理的全流程,一共四個步驟:取 Embedding、逐層計算、Normalization、映射到詞表大小的 logits(后續(xù)會基于此轉(zhuǎn)為概率分布)。
Embedding 是直接從參數(shù)里 copy 出對應(yīng)索引的參數(shù),無序贅述。
Normalization 用的是 RMS(Root Mean Square)Normalization,基本公式如下。 $$ x’i = frac{x_i} {sqrt{sum{i=1}^N x_i}} * w_i $$ 它是標(biāo)準(zhǔn) Normalization 的簡單形式,但效果尚可,其代碼如下。
fn rmsnorm(x: &mut [f32], weight: &[f32]) { let size = x.len(); let squared_sum = x.iter().fold(0.0, |acc, x| acc + x * x); let rms = 1. / (squared_sum / size as f32).sqrt(); x.iter_mut() .zip(weight.iter()) .for_each(|(x, w)| *x *= rms * w); }
代碼一目了然,先一個 reduce,然后開方取倒數(shù),接著就是遍歷計算更新每個參數(shù)值。
最后的矩陣乘法比較標(biāo)準(zhǔn),輸入的 Hidden State(x)因為只有一個 Token,所以可以看成向量,長度為 Dim,與 LM Head 矩陣乘法后就得到一個詞表大小的輸出值,后續(xù)可以歸一化成概率值(即概率分布)。矩陣乘法代碼如下(準(zhǔn)確來說是向量和矩陣乘法)。
fn matmul(target: &mut [f32], w: &[f32], x: &[f32]) { let in_dim = x.len(); target.par_iter_mut().enumerate().for_each(|(i, t)| { let row_offset = i * in_dim; *t = x .iter() .zip(w[row_offset.。].iter()) .fold(0.0, |result, (x, w)| result + x * w); }); }
這里需要注意的是 offset,因為參數(shù)是一個 Vec 存儲的一維數(shù)組,要按二維取值,需要每次跳過對應(yīng)數(shù)量的參數(shù)。剩下的就很清晰了,最終的結(jié)果會存儲到 target,也就是 self.logits,進(jìn)而會轉(zhuǎn)為概率分布。
我們把重心放在中間的逐層計算上,LLM 的核心也在這里。先看 layer 的代碼,如下所示。
fn layer(&mut self, layer: usize, pos: usize) { // Note: we leave the buffer x as it is because we need it for the residual connection rmsnorm_with_dest( self.xb.as_mut_slice(), self.x.as_slice(), &self.transformer.rms_att_weight [layer * self.config.dim.。(layer + 1) * self.config.dim], ); self.attn(layer, pos); // residual connection add_vectors(self.x.as_mut_slice(), self.xb2.as_slice()); // Note: we leave the buffer x as it is because we need it for the residual connection rmsnorm_with_dest( self.xb.as_mut_slice(), self.x.as_slice(), &self.transformer.rms_ffn_weight [layer * self.config.dim.。(layer + 1) * self.config.dim], ); self.ffn(layer); // residual connection add_vectors(self.x.as_mut_slice(), self.xb.as_slice()); }
非常標(biāo)準(zhǔn)的流程(可回看前面的架構(gòu)圖),先歸一化,然后 MHA,殘差連接,再歸一化,F(xiàn)FN,殘差連接。歸一化的代碼剛剛已經(jīng)看過了,這里唯一的不同是將輸出放到第一個參數(shù)(即 self.xb)里。add_vectors 就是對應(yīng)元素值求和,結(jié)果放到第一個參數(shù),這個比較簡單,我們就不放代碼了。重點就是 ffn 和 attn,它們內(nèi)部涉及大量矩陣乘法,我們開始。
先看 ffn,它比較簡單,主要是幾個矩陣乘法加非線性激活,代碼如下。
fn ffn(&mut self, layer: usize) { let weight_from = layer * self.config.hidden_dim * self.config.dim; let weight_to = (layer + 1) * self.config.hidden_dim * self.config.dim; // gate z2 matmul( self.hb.as_mut_slice(), // out: (hidden_dim,) &self.transformer.w1[weight_from..weight_to], // W: (hidden_dim, dim) self.xb.as_slice(), // x: (dim,) ); // up z1 matmul( self.hb2.as_mut_slice(), // out: (hidden_dim,) &self.transformer.w3[weight_from..weight_to], // W: (hidden_dim, dim) self.xb.as_slice(), // x: (dim,) ); // z3 for i in 0..self.config.hidden_dim { self.hb[i] = silu(self.hb[i]) * self.hb2[i]; } // down z4 matmul( self.xb.as_mut_slice(), // out: (hidden_dim,) &self.transformer.w2[weight_from..weight_to], // W: (hidden_dim, dim) self.hb.as_slice(), // x: (dim,) ); }
這個過程和我們《開源代表——LLaMA 結(jié)構(gòu)》一節(jié)中是一一對應(yīng)的,涉及到的主要是剛剛介紹過的 matmul 和一個 silu,后者我們之前看過它的圖像,代碼如下。
fn silu(x: f32) -》 f32 { x / (1.0 + (-x).exp()) }
表達(dá)式如下所示。 $$ ext{SiLU}(x) = frac{x}{1 + e^{-x}} $$ 好了,最后我們把重心放在 attn 這個方法上,由于逐 Token 生成時,Query 是當(dāng)前 Token,這沒問題,但 Key 和 Value(Attention 里面的 K和V)是需要歷史 Token 的(不然怎么算注意力)。常見的做法就是把歷史過程中的 K 和 V 緩存起來,每次生成時順便更新緩存,這樣下次生成時拿到的就是之前的所有 K 和 V。
先看一下基本的代碼流程,如下所示。
fn attn(&mut self, layer: usize, pos: usize) { // qkv matmuls self.attn_qkv_matmuls(layer); // apply RoPE rotation to the q and k vectors for each head self.attn_rope(layer, pos); // Multi-head attention with caching self.cache_kv(layer, pos); self.multihead_attn(layer, pos); // wo let weight_from = layer * self.config.dim * self.config.dim; let weight_to = (layer + 1) * self.config.dim * self.config.dim; matmul( self.xb2.as_mut_slice(), // out: (dim,) &self.transformer.wo[weight_from..weight_to], // W: (dim, dim) self.xb.as_slice(), // x: (dim,) ); }
最后的 wo 比較簡單,不再贅述。一開始的 qkv 也比較簡單,都是矩陣乘法,如下所示。
fn attn_qkv_matmuls(&mut self, layer: usize) { let weight_from = layer * self.config.dim * self.config.dim; let weight_to = (layer + 1) * self.config.dim * self.config.dim; matmul( self.q.as_mut_slice(), // out: (dim,) &self.transformer.wq[weight_from..weight_to], // W: (dim, dim) self.xb.as_slice(), // x: (dim,) ); matmul( self.k.as_mut_slice(), // out: (dim,) &self.transformer.wk[weight_from..weight_to], // W: (dim, dim) self.xb.as_slice(), // x: (dim,) ); matmul( self.v.as_mut_slice(), // out: (dim,) &self.transformer.wv[weight_from..weight_to], // W: (dim, dim) self.xb.as_slice(), // x: (dim,) ); }
還剩下三個方法:attn_rope、cache_kv 和 multihead_attn,我們分別看一下。
第一個用來加入位置信息,參數(shù)是一開始算好的,這里直接取出對應(yīng)位置的值進(jìn)行計算。代碼如下所示。
fn attn_rope(&mut self, layer: usize, pos: usize) { // apply RoPE rotation to the q and k vectors for each head let freq_cis_real_offset = pos * self.config.head_size / 2; let freq_cis_imag_offset = pos * self.config.head_size / 2; for i in (0..self.config.dim).step_by(2) { let q0 = self.q[i]; let q1 = self.q[i + 1]; let k0 = self.k[i]; let k1 = self.k[i + 1]; let cos = self.transformer.freq_cis_real [freq_cis_real_offset + (i % self.config.head_size) / 2]; let sin = self.transformer.freq_cis_imag [freq_cis_imag_offset + (i % self.config.head_size) / 2]; self.q[i] = q0 * cos - q1 * sin; self.q[i + 1] = q1 * cos + q0 * sin; self.k[i] = k0 * cos - k1 * sin; self.k[i + 1] = k1 * cos + k0 * sin; } }
這部分代碼就是把位置信息注入到 Q 和 K 中,其理論分析比較復(fù)雜,此處不展開。
cache_kv 比較簡單,直接把當(dāng)前的 K 和 V 存起來即可,如下所示。
fn cache_kv(&mut self, layer: usize, pos: usize) { // cache the key, value for the current timestep (pos) let layer_offset = layer * self.config.seq_len * self.config.dim; // offset to get to the cache of the current layer let cache_from = layer_offset + pos * self.config.dim; let cache_to = layer_offset + (pos + 1) * self.config.dim; self.key_cache[cache_from..cache_to].copy_from_slice(&self.k.as_slice()); self.value_cache[cache_from..cache_to].copy_from_slice(&self.v.as_slice()); }
因為我們不確定用戶生成的 Token 長度,所以就把最大長度(seq_len)的所有位置都占上,因為是按層存的,每一層都有計算,所以需要層的 ID。每一層、每個位置都緩存 dim 個中間結(jié)果。
最后就是最重要的 multihead_attn 了,這里面的主要邏輯是計算 attention 分?jǐn)?shù),然后得到 attention 之后的結(jié)果,代碼如下。
fn multihead_attn(&mut self, layer: usize, pos: usize) { // offset to get to the cache of the current layer let layer_offset_for_cache = layer * self.config.seq_len * self.config.dim; // 縮放因子 let sqrt_d = (self.config.head_size as f32).sqrt(); // att 和 xb 分別按指定大小切塊 // attn_scores每一塊是seq_len長度,共n_head(NH)塊,即按 head 處理 // xb每一塊是head_size長度,共n_head(NH)塊 self.att.par_chunks_exact_mut(self.config.seq_len) .zip(self.xb.par_chunks_exact_mut(self.config.head_size)) .enumerate() .for_each(|(h, (attn_scores, xb))| { assert_eq!(attn_scores.len(), self.config.seq_len); assert_eq?。▁b.len(), self.config.head_size); // get query vector of the timestep pos for the current head // 第h個head,Q是當(dāng)前Token,(1, HD) let q_from = h * self.config.head_size; let q_to = (h + 1) * self.config.head_size; let q = &self.q[q_from..q_to]; // Compute temp = (K * q_pos) / sqrt(dim) // K和V是要包含歷史Token,(L, HD) // q @ k.T 得到的是 (1,HD)@(HD,L)=(1, L) 大小的 attention score // 這里循環(huán)L(pos)次,所以每一個位置的值是 (1,HD)@(HD,1)=(1,1),即點積 for t in 0.。=pos { // key_cache[l, t] let timestep_and_layer_offset = layer_offset_for_cache + t * self.config.dim; // for the current key, select the correct range which corresponds to the current head let key_vector_from = timestep_and_layer_offset + h * self.config.head_size; let key_vector_to = timestep_and_layer_offset + (h + 1) * self.config.head_size; let key_vector = &self.key_cache[key_vector_from..key_vector_to]; attn_scores[t] = inner_product(q, key_vector) / sqrt_d; } // softmax the scores to get attention weights, from 0..pos inclusively // 歸一化得到概率 softmax(&mut attn_scores[。.(pos + 1)]); // Compute temp2^T * V // 計算加權(quán)的v // attention是 (1,L),V是(L,HD),每個HD的權(quán)重是attention[i] xb.fill(0.0); for t in 0.。=pos { // value_cache[l, t] let timestep_and_layer_offset = layer_offset_for_cache + t * self.config.dim; // for the current value, select the correct range which corresponds to the current head let value_vector_from = timestep_and_layer_offset + h * self.config.head_size; let value_vector_to = timestep_and_layer_offset + (h + 1) * self.config.head_size; let value_vector = &self.value_cache[value_vector_from..value_vector_to]; // weighted sum with attention scores as weights let attention_weight = attn_scores[t]; for i in 0..self.config.head_size { xb[i] += attention_weight * value_vector[i]; } } }); }
上面的過程是分 Head 計算的,需要我們深刻理解前面《開源代表——LLaMA 結(jié)構(gòu)》一小節(jié)的內(nèi)容,具體解釋可以參考代碼里的注釋。值得一提的是,分 Head 計算是并行的。
另外,有個新方法 inner_product 是點積,也就是對應(yīng)元素相乘后求和,代碼如下。
fn inner_product(x: &[f32], y: &[f32]) -》 f32 { zip(x, y).fold(0.0, |acc, (a, b)| acc + a * b) }
比較簡單,不再贅述。
生成
最后就是生成(或 Decoding)過程。代碼略有不同,我們先看下。
fn generate(&mut self, prompt_tokens: &Vec《usize》, n_tokens: usize, temperature: f32) -》 Vec《usize》 { let mut tokens = vec?。郏? tokens.reserve(n_tokens); let mut token = BOS_TOKEN; tokens.push(token); // forward through the prompt to fill up the KV-cache! for (pos, prompt_token) in prompt_tokens.iter().enumerate() { self.forward(token, pos); token = *prompt_token; tokens.push(token); } // complete the prompt for pos in prompt_tokens.len()。.(n_tokens - 1) { self.forward(token, pos); if temperature == 0.0 { token = argmax(self.logits.as_slice()); } else { // Apply temperature and then sample. self.logits.iter_mut().for_each(|p| *p = *p / temperature); softmax(&mut self.logits.as_mut_slice()); token = sample(self.logits.as_slice()); } tokens.push(token); } tokens }
這里有兩個值得注意的地方。
第一個是推理 Prompt(即第一次輸入時的 Context),此時給定的 Context 是多個 Token 組成的,執(zhí)行該過程目的是填充 KV Cache。
第二個是采樣過程,temperature=0.0 時,就是 Greedy Search,每次返回概率最大位置的 Token;否則,會先應(yīng)用 temperature,然后按照概率分布進(jìn)行采樣。temperature 參數(shù)會平滑概率分布,值越大,平滑力度越大,更有可能生成多樣的結(jié)果。softmax 用來把一系列值歸一化成概率分布(所有值加起來和為 1.0)。我們重點看看這個 sample 方法,它的主要思想是根據(jù)概率分布進(jìn)行采樣,也就是高概率的位置更容易被采樣到,低概率的位置更不容易被采樣到。代碼如下。
fn sample(probs: &[f32]) -》 usize { let mut rng = rand::thread_rng(); let mut cdf = 0.0; let r = rng.gen_range(0.0..1.0); for (i, p) in probs.iter().enumerate() { cdf += p; if cdf 》 r { return i; } } probs.len() - 1 }
隨機(jī)生成 0-1 之間的一個值(均勻分布),計算累積概率,當(dāng)累積概率大于剛剛生成的值時,返回此時的位置。這樣就可以保證是按照概率分布進(jìn)行采樣的。我們舉個具體的例子,如下所示。
// 假設(shè)概率分布為 probs = [0.1, 0.2, 0.1, 0.5, 0.1] // 累積概率為 accu_probs = [0.1, 0.3, 0.4, 0.9, 1.0]
假設(shè)隨機(jī)值為 r,因為它是均勻分布的,所以落在不同區(qū)間的概率與該區(qū)間的長度成正比。我們看上面的累積概率,可以得出如下結(jié)果。
r落在區(qū)間返回 Index
[0, 0.1)0
[0.1, 0.3)1
[0.3, 0.4)2
[0.4, 0.9)3
[0.9, 1.0)4
也就是說返回 Index=3 的概率為 0.5,其他同理。
拿到 Token 向量后只要用 Tokenizer 解碼即可得到生成的文本。
小結(jié)
本文我們首先簡單介紹了 LLM 相關(guān)的背景,著重討論了關(guān)于 Token 和生成過程,這是應(yīng)用 LLM 時非常重要的兩個知識點。然后我們介紹了開源 LLM 的代表——LLaMA 的模型結(jié)構(gòu)和參數(shù),給大家一個整體的感知和認(rèn)識。最后就是 Rust 的實現(xiàn),主要包括配置、參數(shù)、模型和生成四個方面,其中最重要的就是模型部分,模型部分最重要、也最難理解的是 Multi-Head Attention 的計算。主要是因為具體的計算過程都是把矩陣運(yùn)算給展開了,這需要對模型有一定程度的理解。
這種展開的寫法其實是比較底層的實現(xiàn),如果能在上面抽象一層,直接操縱矩陣或張量,那計算起來應(yīng)該會簡單很多。事實上,大部分框架都是這么做的,比如 Python 的 NumPy 、PyTorch等,當(dāng)然 Rust 也有類似的框架,比如 NumPy 對應(yīng)的 ndarray,以及 Rust 版本的深度學(xué)習(xí)框架。使用這些框架時,我們使用的是矩陣/張量(或者叫多維數(shù)組)這個對象,所有的操作也都在這個粒度進(jìn)行,這無疑極大地提高了編程效率。同時,還可以利用這些框架底層的性能優(yōu)化。
不過,有時候當(dāng)我們需要框架暫未支持的更細(xì)致的優(yōu)化、或在一個框架不支持的設(shè)備上運(yùn)行時,這種 Pure X(此處為 Rust)的方式就比較方便靈活了。
總的來說,算法是多樣的,實現(xiàn)更是多樣的,優(yōu)化更更是無止境的,吾輩唯有不斷前行,持續(xù)向上。
審核編輯:黃飛
評論
查看更多