注意力機(jī)制
在整個(gè)注意力過(guò)程中,模型會(huì)學(xué)習(xí)了三個(gè)權(quán)重:查詢、鍵和值。查詢、鍵和值的思想來(lái)源于信息檢索系統(tǒng)。所以我們先理解數(shù)據(jù)庫(kù)查詢的思想。
假設(shè)有一個(gè)數(shù)據(jù)庫(kù),里面有所有一些作家和他們的書籍信息。現(xiàn)在我想讀一些Rabindranath寫的書:
在數(shù)據(jù)庫(kù)中,作者名字類似于鍵,圖書類似于值。查詢的關(guān)鍵詞Rabindranath是這個(gè)問(wèn)題的鍵。所以需要計(jì)算查詢和數(shù)據(jù)庫(kù)的鍵(數(shù)據(jù)庫(kù)中的所有作者)之間的相似度,然后返回最相似作者的值(書籍)。
同樣,注意力有三個(gè)矩陣,分別是查詢矩陣(Q)、鍵矩陣(K)和值矩陣(V)。它們中的每一個(gè)都具有與輸入嵌入相同的維數(shù)。模型在訓(xùn)練中學(xué)習(xí)這些度量的值。
我們可以假設(shè)我們從每個(gè)單詞中創(chuàng)建一個(gè)向量,這樣我們就可以處理信息。對(duì)于每個(gè)單詞,生成一個(gè)512維的向量。所有3個(gè)矩陣都是512x512(因?yàn)閱卧~嵌入的維度是512)。對(duì)于每個(gè)標(biāo)記嵌入,我們將其與所有三個(gè)矩陣(Q, K, V)相乘,每個(gè)標(biāo)記將有3個(gè)長(zhǎng)度為512的中間向量。
接下來(lái)計(jì)算分?jǐn)?shù),它是查詢和鍵向量之間的點(diǎn)積。分?jǐn)?shù)決定了當(dāng)我們?cè)谀硞€(gè)位置編碼單詞時(shí),對(duì)輸入句子的其他部分的關(guān)注程度。
然后將點(diǎn)積除以關(guān)鍵向量維數(shù)的平方根。這種縮放是為了防止點(diǎn)積變得太大或太小(取決于正值或負(fù)值),因?yàn)檫@可能導(dǎo)致訓(xùn)練期間的數(shù)值不穩(wěn)定。選擇比例因子是為了確保點(diǎn)積的方差近似等于1。
然后通過(guò)softmax操作傳遞結(jié)果。這將分?jǐn)?shù)標(biāo)準(zhǔn)化:它們都是正的,并且加起來(lái)等于1。softmax輸出決定了我們應(yīng)該從不同的單詞中獲取多少信息或特征(值),也就是在計(jì)算權(quán)重。
這里需要注意的一點(diǎn)是,為什么需要其他單詞的信息/特征?因?yàn)槲覀兊恼Z(yǔ)言是有上下文含義的,一個(gè)相同的單詞出現(xiàn)在不同的語(yǔ)境,含義也不一樣。
最后一步就是計(jì)算softmax與這些值的乘積,并將它們相加。
可視化圖解
上面邏輯都是文字內(nèi)容,看起來(lái)有一些枯燥,下面我們可視化它的矢量化實(shí)現(xiàn)。這樣可以更加深入的理解。
查詢鍵和矩陣的計(jì)算方法如下
同樣的方法可以計(jì)算鍵向量和值向量。
最后計(jì)算得分和注意力輸出。
簡(jiǎn)單代碼實(shí)現(xiàn)
importtorch
importtorch.nnasnn
fromtypingimportList
defget_input_embeddings(words: List[str], embeddings_dim: int):
# we are creating random vector of embeddings_dim size for each words
# normally we train a tokenizer to get the embeddings.
# check the blog on tokenizer to learn about this part
embeddings= [torch.randn(embeddings_dim) forwordinwords]
returnembeddings
text="I should sleep now"
words=text.split(" ")
len(words) # 4
embeddings_dim=512# 512 dim because the original paper uses it. we can use other dim also
embeddings=get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape# torch.Size([512])
# initialize the query, key and value metrices
query_matrix=nn.Linear(embeddings_dim, embeddings_dim)
key_matrix=nn.Linear(embeddings_dim, embeddings_dim)
value_matrix=nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
# query, key and value vectors computation for each words embeddings
query_vectors=torch.stack([query_matrix(embedding) forembeddinginembeddings])
key_vectors=torch.stack([key_matrix(embedding) forembeddinginembeddings])
value_vectors=torch.stack([value_matrix(embedding) forembeddinginembeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
# compute the score
scores=torch.matmul(query_vectors, key_vectors.transpose(-2, -1)) /torch.sqrt(torch.tensor(embeddings_dim, dtype=torch.float32))
scores.shape# torch.Size([4, 4])
# compute the attention weights for each of the words with the other words
softmax=nn.Softmax(dim=-1)
attention_weights=softmax(scores)
attention_weights.shape# torch.Size([4, 4])
# attention output
output=torch.matmul(attention_weights, value_vectors)
output.shape# torch.Size([4, 512])
以上代碼只是為了展示注意力機(jī)制的實(shí)現(xiàn),并未優(yōu)化。
多頭注意力
上面提到的注意力是單頭注意力,在原論文中有8個(gè)頭。對(duì)于多頭和單多頭注意力計(jì)算相同,只是查詢(q0-q3),鍵(k0-k3),值(v0-v3)中間向量會(huì)有一些區(qū)別。
之后將查詢向量分成相等的部分(有多少頭就分成多少)。在上圖中有8個(gè)頭,查詢,鍵和值向量的維度為512。所以就變?yōu)榱?個(gè)64維的向量。
把前64個(gè)向量放到第一個(gè)頭,第二組向量放到第二個(gè)頭,以此類推。在上面的圖片中,我只展示了第一個(gè)頭的計(jì)算。
這里需要注意的是:不同的框架有不同的實(shí)現(xiàn)方法,pytorch官方的實(shí)現(xiàn)是上面這種,但是tf和一些第三方的代碼中是將每個(gè)頭分開計(jì)算了,比如8個(gè)頭會(huì)使用8個(gè)linear(tf的dense)而不是一個(gè)大linear再拆解。還記得Pytorch的transformer里面要求emb_dim能被num_heads整除嗎,就是因?yàn)檫@個(gè)
使用哪種方式都可以,因?yàn)樽罱K的結(jié)果都類似影響不大。
當(dāng)我們?cè)谝粋€(gè)head中有了小查詢、鍵和值(64 dim的)之后,計(jì)算剩下的邏輯與單個(gè)head注意相同。最后得到的64維的向量來(lái)自每個(gè)頭。
我們將每個(gè)頭的64個(gè)輸出組合起來(lái),得到最后的512個(gè)dim輸出向量。
多頭注意力可以表示數(shù)據(jù)中的復(fù)雜關(guān)系。每個(gè)頭都能學(xué)習(xí)不同的模式。多個(gè)頭還提供了同時(shí)處理輸入表示的不同子空間(本例:64個(gè)向量表示512個(gè)原始向量)的能力。
多頭注意代碼實(shí)現(xiàn)
num_heads=8
# batch dim is 1 since we are processing one text.
batch_size=1
text="I should sleep now"
words=text.split(" ")
len(words) # 4
embeddings_dim=512
embeddings=get_input_embeddings(words, embeddings_dim=embeddings_dim)
embeddings[0].shape# torch.Size([512])
# initialize the query, key and value metrices
query_matrix=nn.Linear(embeddings_dim, embeddings_dim)
key_matrix=nn.Linear(embeddings_dim, embeddings_dim)
value_matrix=nn.Linear(embeddings_dim, embeddings_dim)
query_matrix.weight.shape, key_matrix.weight.shape, value_matrix.weight.shape# torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512])
# query, key and value vectors computation for each words embeddings
query_vectors=torch.stack([query_matrix(embedding) forembeddinginembeddings])
key_vectors=torch.stack([key_matrix(embedding) forembeddinginembeddings])
value_vectors=torch.stack([value_matrix(embedding) forembeddinginembeddings])
query_vectors.shape, key_vectors.shape, value_vectors.shape# torch.Size([4, 512]), torch.Size([4, 512]), torch.Size([4, 512])
# (batch_size, num_heads, seq_len, embeddings_dim)
query_vectors_view=query_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
key_vectors_view=key_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
value_vectors_view=value_vectors.view(batch_size, -1, num_heads, embeddings_dim//num_heads).transpose(1, 2)
query_vectors_view.shape, key_vectors_view.shape, value_vectors_view.shape
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64]),
# torch.Size([1, 8, 4, 64])
# We are splitting the each vectors into 8 heads.
# Assuming we have one text (batch size of 1), So we split
# the embedding vectors also into 8 parts. Each head will
# take these parts. If we do this one head at a time.
head1_query_vector=query_vectors_view[0, 0, ...]
head1_key_vector=key_vectors_view[0, 0, ...]
head1_value_vector=value_vectors_view[0, 0, ...]
head1_query_vector.shape, head1_key_vector.shape, head1_value_vector.shape
# The above vectors are of same size as before only the feature dim is changed from 512 to 64
# compute the score
scores_head1=torch.matmul(head1_query_vector, head1_key_vector.permute(1, 0)) /torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
scores_head1.shape# torch.Size([4, 4])
# compute the attention weights for each of the words with the other words
softmax=nn.Softmax(dim=-1)
attention_weights_head1=softmax(scores_head1)
attention_weights_head1.shape# torch.Size([4, 4])
output_head1=torch.matmul(attention_weights_head1, head1_value_vector)
output_head1.shape# torch.Size([4, 512])
# we can compute the output for all the heads
outputs= []
forhead_idxinrange(num_heads):
head_idx_query_vector=query_vectors_view[0, head_idx, ...]
head_idx_key_vector=key_vectors_view[0, head_idx, ...]
head_idx_value_vector=value_vectors_view[0, head_idx, ...]
scores_head_idx=torch.matmul(head_idx_query_vector, head_idx_key_vector.permute(1, 0)) /torch.sqrt(torch.tensor(embeddings_dim//num_heads, dtype=torch.float32))
softmax=nn.Softmax(dim=-1)
attention_weights_idx=softmax(scores_head_idx)
output=torch.matmul(attention_weights_idx, head_idx_value_vector)
outputs.append(output)
[out.shapeforoutinoutputs]
# [torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64]),
# torch.Size([4, 64])]
# stack the result from each heads for the corresponding words
word0_outputs=torch.cat([out[0] foroutinoutputs])
word0_outputs.shape
# lets do it for all the words
attn_outputs= []
foriinrange(len(words)):
attn_output=torch.cat([out[i] foroutinoutputs])
attn_outputs.append(attn_output)
[attn_output.shapeforattn_outputinattn_outputs] # [torch.Size([512]), torch.Size([512]), torch.Size([512]), torch.Size([512])]
# Now lets do it in vectorize way.
# We can not permute the last two dimension of the key vector.
key_vectors_view.permute(0, 1, 3, 2).shape# torch.Size([1, 8, 64, 4])
# Transpose the key vector on the last dim