0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

究竟Self-Attention結(jié)構(gòu)是怎樣的?

WpOh_rgznai100 ? 來(lái)源:lq ? 2019-07-18 14:29 ? 次閱讀

一、Self-Attention概念詳解

了解了模型大致原理,我們可以詳細(xì)的看一下究竟Self-Attention結(jié)構(gòu)是怎樣的。其基本結(jié)構(gòu)如下

對(duì)于self-attention來(lái)講,Q(Query), K(Key), V(Value)三個(gè)矩陣均來(lái)自同一輸入,首先我們要計(jì)算Q與K之間的點(diǎn)乘,然后為了防止其結(jié)果過(guò)大,會(huì)除以一個(gè)尺度標(biāo)度,其中為一個(gè)query和key向量的維度。再利用Softmax操作將其結(jié)果歸一化為概率分布,然后再乘以矩陣V就得到權(quán)重求和的表示。該操作可以表示為

這里可能比較抽象,我們來(lái)看一個(gè)具體的例子(圖片來(lái)源于https://jalammar.github.io/illustrated-transformer/),該博客講解的極其清晰,強(qiáng)烈推薦),假如我們要翻譯一個(gè)詞組Thinking Machines,其中Thinking的輸入的embedding vector用表示,Machines的embedding vector用表示。

當(dāng)我們處理Thinking這個(gè)詞時(shí),我們需要計(jì)算句子中所有詞與它的Attention Score,這就像將當(dāng)前詞作為搜索的query,去和句子中所有詞(包含該詞本身)的key去匹配,看看相關(guān)度有多高。我們用代表Thinking對(duì)應(yīng)的query vector,及分別代表Thinking以及Machines對(duì)應(yīng)的key vector,則計(jì)算Thinking的attention score的時(shí)候我們需要計(jì)算與的點(diǎn)乘,同理,我們計(jì)算Machines的attention score的時(shí)候需要計(jì)算與的點(diǎn)乘。如上圖中所示我們分別得到了與的點(diǎn)乘積,然后我們進(jìn)行尺度縮放與softmax歸一化,如下圖所示:

顯然,當(dāng)前單詞與其自身的attention score一般最大,其他單詞根據(jù)與當(dāng)前單詞重要程度有相應(yīng)的score。然后我們?cè)谟眠@些attention score與value vector相乘,得到加權(quán)的向量。

如果將輸入的所有向量合并為矩陣形式,則所有query, key, value向量也可以合并為矩陣形式表示:

其中是我們模型訓(xùn)練過(guò)程學(xué)習(xí)到的合適的參數(shù)。上述操作即可簡(jiǎn)化為矩陣形式:

二、Self_Attention模型搭建

筆者使用Keras來(lái)實(shí)現(xiàn)對(duì)于Self_Attention模型的搭建,由于網(wǎng)絡(luò)中間參數(shù)量比較多,這里采用自定義網(wǎng)絡(luò)層的方法構(gòu)建Self_Attention。

Keras實(shí)現(xiàn)自定義網(wǎng)絡(luò)層。需要實(shí)現(xiàn)以下三個(gè)方法:(注意input_shape是包含batch_size項(xiàng)的)

build(input_shape): 這是你定義權(quán)重的地方。這個(gè)方法必須設(shè)self.built = True,可以通過(guò)調(diào)用super([Layer], self).build()完成。

call(x): 這里是編寫(xiě)層的功能邏輯的地方。你只需要關(guān)注傳入call的第一個(gè)參數(shù):輸入張量,除非你希望你的層支持masking。

compute_output_shape(input_shape): 如果你的層更改了輸入張量的形狀,你應(yīng)該在這里定義形狀變化的邏輯,這讓Keras能夠自動(dòng)推斷各層的形狀。

實(shí)現(xiàn)代碼如下:

from keras.preprocessing import sequencefrom keras.datasets import imdbfrom matplotlib import pyplot as pltimport pandas as pdfrom keras import backend as Kfrom keras.engine.topology import Layerclass Self_Attention(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(Self_Attention, self).__init__(**kwargs) def build(self, input_shape): # 為該層創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重 #inputs.shape = (batch_size, time_steps, seq_len) self.kernel = self.add_weight(name='kernel', shape=(3,input_shape[2], self.output_dim), initializer='uniform', trainable=True) super(Self_Attention, self).build(input_shape) # 一定要在最后調(diào)用它 def call(self, x): WQ = K.dot(x, self.kernel[0]) WK = K.dot(x, self.kernel[1]) WV = K.dot(x, self.kernel[2]) print("WQ.shape",WQ.shape) print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape) QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1])) QK = QK / (64**0.5) QK = K.softmax(QK) print("QK.shape",QK.shape) V = K.batch_dot(QK,WV) return V def compute_output_shape(self, input_shape): return (input_shape[0],input_shape[1],self.output_dim)

這里可以對(duì)照一中的概念講解來(lái)理解代碼

如果將輸入的所有向量合并為矩陣形式,則所有query, key, value向量也可以合并為矩陣形式表示

上述內(nèi)容對(duì)應(yīng)

WQ = K.dot(x, self.kernel[0])WK = K.dot(x, self.kernel[1])WV = K.dot(x, self.kernel[2])

其中是我們模型訓(xùn)練過(guò)程學(xué)習(xí)到的合適的參數(shù)。上述操作即可簡(jiǎn)化為矩陣形式:

上述內(nèi)容對(duì)應(yīng)(為什么使用batch_dot呢?這是由于input_shape是包含batch_size項(xiàng)的)

QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1]))QK = QK / (64**0.5)QK = K.softmax(QK)print("QK.shape",QK.shape)V = K.batch_dot(QK,WV)

這里QK = QK / (64**0.5) 是除以一個(gè)歸一化系數(shù),(64**0.5)是筆者自己定義的,其他文章可能會(huì)采用不同的方法。

三、訓(xùn)練網(wǎng)絡(luò)

項(xiàng)目完整代碼如下,這里使用的是Keras自帶的imdb影評(píng)數(shù)據(jù)集。

#%%from keras.preprocessing import sequencefrom keras.datasets import imdbfrom matplotlib import pyplot as pltimport pandas as pdfrom keras import backend as Kfrom keras.engine.topology import Layerclass Self_Attention(Layer): def __init__(self, output_dim, **kwargs): self.output_dim = output_dim super(Self_Attention, self).__init__(**kwargs) def build(self, input_shape): # 為該層創(chuàng)建一個(gè)可訓(xùn)練的權(quán)重 #inputs.shape = (batch_size, time_steps, seq_len) self.kernel = self.add_weight(name='kernel', shape=(3,input_shape[2], self.output_dim), initializer='uniform', trainable=True) super(Self_Attention, self).build(input_shape) # 一定要在最后調(diào)用它 def call(self, x): WQ = K.dot(x, self.kernel[0]) WK = K.dot(x, self.kernel[1]) WV = K.dot(x, self.kernel[2]) print("WQ.shape",WQ.shape) print("K.permute_dimensions(WK, [0, 2, 1]).shape",K.permute_dimensions(WK, [0, 2, 1]).shape) QK = K.batch_dot(WQ,K.permute_dimensions(WK, [0, 2, 1])) QK = QK / (64**0.5) QK = K.softmax(QK) print("QK.shape",QK.shape) V = K.batch_dot(QK,WV) return V def compute_output_shape(self, input_shape): return (input_shape[0],input_shape[1],self.output_dim)max_features = 20000print('Loading data...')(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)#標(biāo)簽轉(zhuǎn)換為獨(dú)熱碼y_train, y_test = pd.get_dummies(y_train),pd.get_dummies(y_test)print(len(x_train), 'train sequences')print(len(x_test), 'test sequences')#%%數(shù)據(jù)歸一化處理maxlen = 64print('Pad sequences (samples x time)')x_train = sequence.pad_sequences(x_train, maxlen=maxlen)x_test = sequence.pad_sequences(x_test, maxlen=maxlen)print('x_train shape:', x_train.shape)print('x_test shape:', x_test.shape)#%%batch_size = 32from keras.models import Modelfrom keras.optimizers import SGD,Adamfrom keras.layers import *from Attention_keras import Attention,Position_EmbeddingS_inputs = Input(shape=(64,), dtype='int32')embeddings = Embedding(max_features, 128)(S_inputs)O_seq = Self_Attention(128)(embeddings)O_seq = GlobalAveragePooling1D()(O_seq)O_seq = Dropout(0.5)(O_seq)outputs = Dense(2, activation='softmax')(O_seq)model = Model(inputs=S_inputs, outputs=outputs)print(model.summary())# try using different optimizers and different optimizer configsopt = Adam(lr=0.0002,decay=0.00001)loss = 'categorical_crossentropy'model.compile(loss=loss, optimizer=opt, metrics=['accuracy'])#%%print('Train...')h = model.fit(x_train, y_train, batch_size=batch_size, epochs=5, validation_data=(x_test, y_test))plt.plot(h.history["loss"],label="train_loss")plt.plot(h.history["val_loss"],label="val_loss")plt.plot(h.history["acc"],label="train_acc")plt.plot(h.history["val_acc"],label="val_acc")plt.legend()plt.show()#model.save("imdb.h5")

四、結(jié)果輸出

(TF_GPU) D:FilesDATAsprjspython f_keras ransfromerdemo>C:/Files/APPs/RuanJian/Miniconda3/envs/TF_GPU/python.exe d:/Files/DATAs/prjs/python/tf_keras/transfromerdemo/train.1.pyUsing TensorFlow backend.Loading data...25000 train sequences25000 test sequencesPad sequences (samples x time)x_train shape: (25000, 64)x_test shape: (25000, 64)WQ.shape (?, 64, 128)K.permute_dimensions(WK, [0, 2, 1]).shape (?, 128, 64)QK.shape (?, 64, 64)_________________________________________________________________Layer (type) Output Shape Param #=================================================================input_1 (InputLayer) (None, 64) 0_________________________________________________________________embedding_1 (Embedding) (None, 64, 128) 2560000_________________________________________________________________self__attention_1 (Self_Atte (None, 64, 128) 49152_________________________________________________________________global_average_pooling1d_1 ( (None, 128) 0_________________________________________________________________dropout_1 (Dropout) (None, 128) 0_________________________________________________________________dense_1 (Dense) (None, 2) 258=================================================================Total params: 2,609,410Trainable params: 2,609,410Non-trainable params: 0_________________________________________________________________NoneTrain...Train on 25000 samples, validate on 25000 samplesEpoch 1/525000/25000 [==============================] - 17s 693us/step - loss: 0.5244 - acc: 0.7514 - val_loss: 0.3834 - val_acc: 0.8278Epoch 2/525000/25000 [==============================] - 15s 615us/step - loss: 0.3257 - acc: 0.8593 - val_loss: 0.3689 - val_acc: 0.8368Epoch 3/525000/25000 [==============================] - 15s 614us/step - loss: 0.2602 - acc: 0.8942 - val_loss: 0.3909 - val_acc: 0.8303Epoch 4/525000/25000 [==============================] - 15s 618us/step - loss: 0.2078 - acc: 0.9179 - val_loss: 0.4482 - val_acc: 0.8215Epoch 5/525000/25000 [==============================] - 15s 619us/step - loss: 0.1639 - acc: 0.9368 - val_loss: 0.5313 - val_acc: 0.8106

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 計(jì)算
    +關(guān)注

    關(guān)注

    2

    文章

    437

    瀏覽量

    38605
  • 矩陣
    +關(guān)注

    關(guān)注

    0

    文章

    417

    瀏覽量

    34413
  • 機(jī)制
    +關(guān)注

    關(guān)注

    0

    文章

    24

    瀏覽量

    9756

原文標(biāo)題:機(jī)器如何讀懂人心:Keras實(shí)現(xiàn)Self-Attention文本分類

文章出處:【微信號(hào):rgznai100,微信公眾號(hào):rgznai100】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    面向列的HBase存儲(chǔ)結(jié)構(gòu)究竟有什么樣的不同之處呢?

    HBase是什么?HBase的存儲(chǔ)結(jié)構(gòu)究竟怎樣的呢?面向列的HBase存儲(chǔ)結(jié)構(gòu)究竟有什么樣的不同之處呢?
    發(fā)表于 06-16 06:52

    暴風(fēng)電視拆機(jī)圖解 內(nèi)部結(jié)構(gòu)究竟怎樣

    繼樂(lè)視、小米之后,國(guó)內(nèi)另一家互聯(lián)網(wǎng)公司--暴風(fēng)影音也發(fā)布了旗下的智能電視產(chǎn)品,而且同樣是分體式設(shè)計(jì)。而今天我們就來(lái)看看愛(ài)玩客帶來(lái)的暴風(fēng)電視的內(nèi)部結(jié)構(gòu)究竟怎樣的。
    的頭像 發(fā)表于 09-04 14:32 ?3.1w次閱讀

    BERT模型的PyTorch實(shí)現(xiàn)

    BertModel是一個(gè)基本的BERT Transformer模型,包含一個(gè)summed token、位置和序列嵌入層,然后是一系列相同的self-attention blocks(BERT-base是12個(gè)blocks, BERT-large是24個(gè)blocks)。
    的頭像 發(fā)表于 11-13 09:12 ?1.4w次閱讀

    為什么要有attention機(jī)制,Attention原理

    沒(méi)有attention機(jī)制的encoder-decoder結(jié)構(gòu)通常把encoder的最后一個(gè)狀態(tài)作為decoder的輸入(可能作為初始化,也可能作為每一時(shí)刻的輸入),但是encoder的state
    的頭像 發(fā)表于 03-06 14:11 ?1.7w次閱讀
    為什么要有<b class='flag-5'>attention</b>機(jī)制,<b class='flag-5'>Attention</b>原理

    AAAI 2019 Gaussian Transformer 一種自然語(yǔ)言推理方法

    自然語(yǔ)言推理 (Natural Language Inference, NLI) 是一個(gè)活躍的研究領(lǐng)域,許多基于循環(huán)神經(jīng)網(wǎng)絡(luò)(RNNs),卷積神經(jīng)網(wǎng)絡(luò)(CNNs),self-attention 網(wǎng)絡(luò) (SANs) 的模型為此提出。
    的頭像 發(fā)表于 05-14 09:45 ?3003次閱讀
    AAAI 2019 Gaussian Transformer 一種自然語(yǔ)言推理方法

    解析Transformer中的位置編碼 -- ICLR 2021

    引言 Transformer是近年來(lái)非常流行的處理序列到序列問(wèn)題的架構(gòu),其self-attention機(jī)制允許了長(zhǎng)距離的詞直接聯(lián)系,可以使模型更容易學(xué)習(xí)序列的長(zhǎng)距離依賴。由于其優(yōu)良的可并行性以及可觀
    的頭像 發(fā)表于 04-01 16:07 ?1.2w次閱讀
    解析Transformer中的位置編碼 -- ICLR 2021

    一個(gè)LSTM被分解成垂直和水平的LSTM

    Vision Transformer成功的原因被認(rèn)為是由于Self-Attention建模遠(yuǎn)程依賴的能力。然而,Self-Attention對(duì)于Transformer執(zhí)行視覺(jué)任務(wù)的有效性有多重要還不清楚。事實(shí)上,只基于多層感知器(MLPs)的MLP-Mixer被提議作為V
    的頭像 發(fā)表于 05-07 16:29 ?1291次閱讀

    全球首個(gè)面向遙感任務(wù)設(shè)計(jì)的億級(jí)視覺(jué)Transformer大模型

    Attention, RVSA)來(lái)代替Transformer中的原始完全注意力(Vanilla Full Self-Attention),它可以從生成的不同窗口中提取豐富的上下文信息來(lái)學(xué)習(xí)更好的目標(biāo)表征,并顯著降低計(jì)算成本和內(nèi)存占用。
    的頭像 發(fā)表于 12-09 14:53 ?655次閱讀

    基于視覺(jué)transformer的高效時(shí)空特征學(xué)習(xí)算法

    視覺(jué)Transofrmer通常將圖像分割為不重疊的塊(patch),patch之間通過(guò)自注意力機(jī)制(Self-Attention)進(jìn)行特征聚合,patch內(nèi)部通過(guò)全連接層(FFN)進(jìn)行特征映射。每個(gè)
    的頭像 發(fā)表于 12-12 15:01 ?1331次閱讀

    簡(jiǎn)述深度學(xué)習(xí)中的Attention機(jī)制

    Attention機(jī)制在深度學(xué)習(xí)中得到了廣泛的應(yīng)用,本文通過(guò)公式及圖片詳細(xì)講解attention機(jī)制的計(jì)算過(guò)程及意義,首先從最早引入attention到機(jī)器翻譯任務(wù)(Bahdanau et al. ICLR2014)的方法講起。
    的頭像 發(fā)表于 02-22 14:21 ?1472次閱讀
    簡(jiǎn)述深度學(xué)習(xí)中的<b class='flag-5'>Attention</b>機(jī)制

    解析ChatGPT背后的技術(shù)演進(jìn)

    ?! ?)Transformer模型沒(méi)有使用傳統(tǒng)的CNN和RNN結(jié)構(gòu),其完全是由Attention機(jī)制組成,其中Self-Attention(自注意力)是Transformer的核心?! ?)OpenAI的GPT模型和Googl
    發(fā)表于 03-29 16:57 ?1次下載

    如何入門(mén)面向自動(dòng)駕駛領(lǐng)域的視覺(jué)Transformer?

    理解Transformer背后的理論基礎(chǔ),比如自注意力機(jī)制(self-attention), 位置編碼(positional embedding),目標(biāo)查詢(object query)等等,網(wǎng)上的資料比較雜亂,不夠系統(tǒng),難以通過(guò)自學(xué)做到深入理解并融會(huì)貫通。
    的頭像 發(fā)表于 07-09 14:35 ?494次閱讀
    如何入門(mén)面向自動(dòng)駕駛領(lǐng)域的視覺(jué)Transformer?

    基于Transformer的目標(biāo)檢測(cè)算法的3個(gè)難點(diǎn)

    理解Transformer背后的理論基礎(chǔ),比如自注意力機(jī)制(self-attention), 位置編碼(positional embedding),目標(biāo)查詢(object query)等等,網(wǎng)上的資料比較雜亂,不夠系統(tǒng),難以通過(guò)自學(xué)做到深入理解并融會(huì)貫通。
    發(fā)表于 07-18 12:54 ?587次閱讀
    基于Transformer的目標(biāo)檢測(cè)算法的3個(gè)難點(diǎn)

    基于Transformer的目標(biāo)檢測(cè)算法難點(diǎn)

    理解Transformer背后的理論基礎(chǔ),比如自注意力機(jī)制(self-attention), 位置編碼(positional embedding),目標(biāo)查詢(object query)等等,網(wǎng)上的資料比較雜亂,不夠系統(tǒng),難以通過(guò)自學(xué)做到深入理解并融會(huì)貫通。
    發(fā)表于 08-24 11:19 ?258次閱讀
    基于Transformer的目標(biāo)檢測(cè)算法難點(diǎn)

    視覺(jué)Transformer基本原理及目標(biāo)檢測(cè)應(yīng)用

    視覺(jué)Transformer的一般結(jié)構(gòu)如圖2所示,包括編碼器和解碼器兩部分,其中編碼器每一層包括一個(gè)多頭自注意力模塊(self-attention)和一個(gè)位置前饋神經(jīng)網(wǎng)絡(luò)(FFN)。
    發(fā)表于 04-03 10:32 ?2644次閱讀
    視覺(jué)Transformer基本原理及目標(biāo)檢測(cè)應(yīng)用