PyTorch教程-11.5。多頭注意力

jf_pJlTbmA9 · Source: PyTorch · Author: PyTorch · 2023-06-05 15:44

在實(shí)踐中,給定一組相同的查詢、鍵和值,我們可能希望我們的模型結(jié)合來自同一注意機(jī)制的不同行為的知識,例如捕獲各種范圍的依賴關(guān)系(例如,較短范圍與較長范圍)在一個序列中。因此,這可能是有益的

允許我們的注意力機(jī)制聯(lián)合使用查詢、鍵和值的不同表示子空間。

為此,可以使用以下方式轉(zhuǎn)換查詢、鍵和值,而不是執(zhí)行單個注意力池h獨(dú)立學(xué)習(xí)線性投影。那么這些h投影查詢、鍵和值被并行輸入注意力池。到底,h 注意池的輸出與另一個學(xué)習(xí)的線性投影連接并轉(zhuǎn)換以產(chǎn)生最終輸出。這種設(shè)計稱為多頭注意力,其中每個hattention pooling outputs 是一個頭 (Vaswani et al. , 2017)。使用全連接層執(zhí)行可學(xué)習(xí)的線性變換,圖 11.5.1描述了多頭注意力。

Fig. 11.5.1 Multi-head attention, where multiple heads are concatenated and then linearly transformed.

import math
import torch
from torch import nn
from d2l import torch as d2l

import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import tensorflow as tf
from d2l import tensorflow as d2l

11.5.1. Model

在提供多頭注意力的實(shí)現(xiàn)之前,讓我們從數(shù)學(xué)上形式化這個模型。給定一個查詢 q∈Rdq, 關(guān)鍵 k∈Rdk和一個值 v∈Rdv, 每個注意力頭 hi(i=1,…,h) 被計算為

(11.5.1)hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv,

其中可學(xué)習(xí)參數(shù) Wi(q)∈Rpq×dq, Wi(k)∈Rpk×dk和 Wi(v)∈Rpv×dv, 和f是注意力集中,例如11.3 節(jié)中的附加注意力和縮放點(diǎn)積注意力。多頭注意力輸出是另一種通過可學(xué)習(xí)參數(shù)進(jìn)行的線性變換Wo∈Rpo×hpv的串聯(lián)h負(fù)責(zé)人:

(11.5.2)Wo[h1?hh]∈Rpo.

基于這種設(shè)計,每個頭可能會關(guān)注輸入的不同部分。可以表達(dá)比簡單加權(quán)平均更復(fù)雜的函數(shù)。

11.5.2。執(zhí)行

在我們的實(shí)現(xiàn)中,我們?yōu)槎囝^注意力的每個頭選擇縮放的點(diǎn)積注意力。為了避免計算成本和參數(shù)化成本的顯著增長,我們設(shè)置 pq=pk=pv=po/h. 注意h如果我們將查詢、鍵和值的線性變換的輸出數(shù)量設(shè)置為 pqh=pkh=pvh=po. 在下面的實(shí)現(xiàn)中, po通過參數(shù)指定num_hiddens。

class MultiHeadAttention(d2l.Module): #@save
  """Multi-head attention."""
  def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
    super().__init__()
    self.num_heads = num_heads
    self.attention = d2l.DotProductAttention(dropout)
    self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
    self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

  def forward(self, queries, keys, values, valid_lens):
    # Shape of queries, keys, or values:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    # After transposing, shape of output queries, keys, or values:
    # (batch_size * num_heads, no. of queries or key-value pairs,
    # num_hiddens / num_heads)
    queries = self.transpose_qkv(self.W_q(queries))
    keys = self.transpose_qkv(self.W_k(keys))
    values = self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      # On axis 0, copy the first item (scalar or vector) for num_heads
      # times, then copy the next item, and so on
      valid_lens = torch.repeat_interleave(
        valid_lens, repeats=self.num_heads, dim=0)

    # Shape of output: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads)
    output = self.attention(queries, keys, values, valid_lens)
    # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
    output_concat = self.transpose_output(output)
    return self.W_o(output_concat)

class MultiHeadAttention(d2l.Module): #@save
  """Multi-head attention."""
  def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
         **kwargs):
    super().__init__()
    self.num_heads = num_heads
    self.attention = d2l.DotProductAttention(dropout)
    self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
    self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
    self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
    self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

  def forward(self, queries, keys, values, valid_lens):
    # Shape of queries, keys, or values:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    # After transposing, shape of output queries, keys, or values:
    # (batch_size * num_heads, no. of queries or key-value pairs,
    # num_hiddens / num_heads)
    queries = self.transpose_qkv(self.W_q(queries))
    keys = self.transpose_qkv(self.W_k(keys))
    values = self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      # On axis 0, copy the first item (scalar or vector) for num_heads
      # times, then copy the next item, and so on
      valid_lens = valid_lens.repeat(self.num_heads, axis=0)

    # Shape of output: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads)
    output = self.attention(queries, keys, values, valid_lens)

    # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
    output_concat = self.transpose_output(output)
    return self.W_o(output_concat)

class MultiHeadAttention(nn.Module): #@save
  num_hiddens: int
  num_heads: int
  dropout: float
  bias: bool = False

  def setup(self):
    self.attention = d2l.DotProductAttention(self.dropout)
    self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias)
    self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias)
    self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias)
    self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias)

  @nn.compact
  def __call__(self, queries, keys, values, valid_lens, training=False):
    # Shape of queries, keys, or values:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    # After transposing, shape of output queries, keys, or values:
    # (batch_size * num_heads, no. of queries or key-value pairs,
    # num_hiddens / num_heads)
    queries = self.transpose_qkv(self.W_q(queries))
    keys = self.transpose_qkv(self.W_k(keys))
    values = self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      # On axis 0, copy the first item (scalar or vector) for num_heads
      # times, then copy the next item, and so on
      valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

    # Shape of output: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads)
    output, attention_weights = self.attention(
      queries, keys, values, valid_lens, training=training)
    # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
    output_concat = self.transpose_output(output)
    return self.W_o(output_concat), attention_weights

class MultiHeadAttention(d2l.Module): #@save
  """Multi-head attention."""
  def __init__(self, key_size, query_size, value_size, num_hiddens,
         num_heads, dropout, bias=False, **kwargs):
    super().__init__()
    self.num_heads = num_heads
    self.attention = d2l.DotProductAttention(dropout)
    self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
    self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
    self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
    self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

  def call(self, queries, keys, values, valid_lens, **kwargs):
    # Shape of queries, keys, or values:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    # After transposing, shape of output queries, keys, or values:
    # (batch_size * num_heads, no. of queries or key-value pairs,
    # num_hiddens / num_heads)
    queries = self.transpose_qkv(self.W_q(queries))
    keys = self.transpose_qkv(self.W_k(keys))
    values = self.transpose_qkv(self.W_v(values))

    if valid_lens is not None:
      # On axis 0, copy the first item (scalar or vector) for num_heads
      # times, then copy the next item, and so on
      valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

    # Shape of output: (batch_size * num_heads, no. of queries,
    # num_hiddens / num_heads)
    output = self.attention(queries, keys, values, valid_lens, **kwargs)

    # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
    output_concat = self.transpose_output(output)
    return self.W_o(output_concat)

為了允許多個頭的并行計算,上面的 MultiHeadAttention類使用了下面定義的兩種轉(zhuǎn)置方法。具體地,該transpose_output方法將方法的操作反轉(zhuǎn)transpose_qkv。

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_qkv(self, X):
  """Transposition for parallel computation of multiple attention heads."""
  # Shape of input X: (batch_size, no. of queries or key-value pairs,
  # num_hiddens). Shape of output X: (batch_size, no. of queries or
  # key-value pairs, num_heads, num_hiddens / num_heads)
  X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
  # Shape of output X: (batch_size, num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  X = X.permute(0, 2, 1, 3)
  # Shape of output: (batch_size * num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_output(self, X):
  """Reverse the operation of transpose_qkv."""
  X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
  X = X.permute(0, 2, 1, 3)
  return X.reshape(X.shape[0], X.shape[1], -1)

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_qkv(self, X):
  """Transposition for parallel computation of multiple attention heads."""
  # Shape of input X: (batch_size, no. of queries or key-value pairs,
  # num_hiddens). Shape of output X: (batch_size, no. of queries or
  # key-value pairs, num_heads, num_hiddens / num_heads)
  X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
  # Shape of output X: (batch_size, num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  X = X.transpose(0, 2, 1, 3)
  # Shape of output: (batch_size * num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_output(self, X):
  """Reverse the operation of transpose_qkv."""
  X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
  X = X.transpose(0, 2, 1, 3)
  return X.reshape(X.shape[0], X.shape[1], -1)

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_qkv(self, X):
  """Transposition for parallel computation of multiple attention heads."""
  # Shape of input X: (batch_size, no. of queries or key-value pairs,
  # num_hiddens). Shape of output X: (batch_size, no. of queries or
  # key-value pairs, num_heads, num_hiddens / num_heads)
  X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
  # Shape of output X: (batch_size, num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  X = jnp.transpose(X, (0, 2, 1, 3))
  # Shape of output: (batch_size * num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  return X.reshape((-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_output(self, X):
  """Reverse the operation of transpose_qkv."""
  X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
  X = jnp.transpose(X, (0, 2, 1, 3))
  return X.reshape((X.shape[0], X.shape[1], -1))

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_qkv(self, X):
  """Transposition for parallel computation of multiple attention heads."""
  # Shape of input X: (batch_size, no. of queries or key-value pairs,
  # num_hiddens). Shape of output X: (batch_size, no. of queries or
  # key-value pairs, num_heads, num_hiddens / num_heads)
  X = tf.reshape(X, shape=(X.shape[0], X.shape[1], self.num_heads, -1))
  # Shape of output X: (batch_size, num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  X = tf.transpose(X, perm=(0, 2, 1, 3))
  # Shape of output: (batch_size * num_heads, no. of queries or key-value
  # pairs, num_hiddens / num_heads)
  return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention) #@save
def transpose_output(self, X):
  """Reverse the operation of transpose_qkv."""
  X = tf.reshape(X, shape=(-1, self.num_heads, X.shape[1], X.shape[2]))
  X = tf.transpose(X, perm=(0, 2, 1, 3))
  return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
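The two methods are exact inverses of each other. As a quick check (a sketch assuming the PyTorch variant of MultiHeadAttention defined above, with arbitrary example sizes), transpose_output recovers the original tensor after transpose_qkv:

# Round-trip check, assuming the PyTorch MultiHeadAttention defined above
attention = MultiHeadAttention(num_hiddens=100, num_heads=5, dropout=0.0)
X = torch.randn(2, 4, 100)       # (batch_size, seq_len, num_hiddens)
X_split = attention.transpose_qkv(X)  # (batch_size * num_heads, seq_len, num_hiddens / num_heads)
print(X_split.shape)         # torch.Size([10, 4, 20])
assert torch.equal(attention.transpose_output(X_split), X)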

讓我們MultiHeadAttention使用一個玩具示例來測試我們實(shí)現(xiàn)的類,其中鍵和值相同。因此,多頭注意力輸出的形狀為 ( batch_size, num_queries, num_hiddens)。

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
        (batch_size, num_queries, num_hiddens))

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
        (batch_size, num_queries, num_hiddens))

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = jnp.array([3, 2])
X = jnp.ones((batch_size, num_queries, num_hiddens))
Y = jnp.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, Y, Y, valid_lens,
                      training=False)[0][0],
        (batch_size, num_queries, num_hiddens))

num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                num_hiddens, num_heads, 0.5)

batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens, training=False),
        (batch_size, num_queries, num_hiddens))

11.5.3. Summary

多頭注意力通過查詢、鍵和值的不同表示子空間結(jié)合相同注意力池的知識。要并行計算多頭注意的多個頭,需要適當(dāng)?shù)膹埩坎僮鳌?/p>

11.5.4。練習(xí)

可視化本實(shí)驗中多個頭的注意力權(quán)重。

假設(shè)我們有一個基于多頭注意力的訓(xùn)練模型,我們想要修剪最不重要的注意力頭以提高預(yù)測速度。我們?nèi)绾卧O(shè)計實(shí)驗來衡量注意力頭的重要性?

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    794

    瀏覽量

    13010
收藏 人收藏

    評論

    相關(guān)推薦

    基于labview的注意力分配實(shí)驗設(shè)計

    畢設(shè)要求做一個注意力分配實(shí)驗設(shè)計。有些結(jié)構(gòu)完全想不明白。具體如何實(shí)現(xiàn)如下。一個大概5*5的燈組合,要求隨機(jī)亮。兩個聲音大小不同的音頻,要求隨機(jī)響,有大、小兩個選項。以上兩種需要記錄并計算錯誤率。體現(xiàn)在表格上。大家可不可以勞煩幫個忙,幫我構(gòu)思一下, 或者幫我做一下。拜托大家了。
    發(fā)表于 05-07 20:33

    深度分析NLP中的注意力機(jī)制

    注意力機(jī)制越發(fā)頻繁的出現(xiàn)在文獻(xiàn)中,因此對注意力機(jī)制的學(xué)習(xí)、掌握與應(yīng)用顯得十分重要。本文便對注意力機(jī)制做了較為全面的綜述。
    的頭像 發(fā)表于 02-17 09:18 ?3754次閱讀

    融合雙層多頭注意力與CNN的回歸模型

    針對現(xiàn)有文本情感分析方法存在的無法高效捕捉相關(guān)文本情感特征從而造成情感分析效果不佳的問題提出一種融合雙層多頭注意力與卷積神經(jīng)網(wǎng)絡(luò)(CNN)的回歸模型 DLMA-CNN。采用多頭注意力
    發(fā)表于 03-25 15:16 ?6次下載
    融合雙層<b class='flag-5'>多頭</b>自<b class='flag-5'>注意力</b>與CNN的回歸模型

    基于注意力機(jī)制等的社交網(wǎng)絡(luò)熱度預(yù)測模型

    基于注意力機(jī)制等的社交網(wǎng)絡(luò)熱度預(yù)測模型
    發(fā)表于 06-07 15:12 ?14次下載

    基于注意力機(jī)制的跨域服裝檢索方法綜述

    基于注意力機(jī)制的跨域服裝檢索方法綜述
    發(fā)表于 06-27 10:33 ?2次下載

    基于注意力機(jī)制的新聞文本分類模型

    基于注意力機(jī)制的新聞文本分類模型
    發(fā)表于 06-27 15:32 ?30次下載

    基于超大感受野注意力的超分辨率模型

    通過引入像素注意力,PAN在大幅降低參數(shù)量的同時取得了非常優(yōu)秀的性能。相比通道注意力與空域注意力,像素注意力是一種更廣義的注意力形式,為進(jìn)一
    的頭像 發(fā)表于 10-27 13:55 ?1008次閱讀

    如何用番茄鐘提高注意力

    電子發(fā)燒友網(wǎng)站提供《如何用番茄鐘提高注意力.zip》資料免費(fèi)下載
    發(fā)表于 10-28 14:29 ?0次下載
    如何用番茄鐘提高<b class='flag-5'>注意力</b>

    詳解五種即插即用的視覺注意力模塊

    SE注意力模塊的全稱是Squeeze-and-Excitation block、其中Squeeze實(shí)現(xiàn)全局信息嵌入、Excitation實(shí)現(xiàn)自適應(yīng)權(quán)重矯正,合起來就是SE注意力模塊。
    的頭像 發(fā)表于 05-18 10:23 ?2317次閱讀
    詳解五種即插即用的視覺<b class='flag-5'>注意力</b>模塊

    PyTorch教程11.4之Bahdanau注意力機(jī)制

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程11.4之Bahdanau注意力機(jī)制.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:11 ?0次下載
    <b class='flag-5'>PyTorch</b>教程11.4之Bahdanau<b class='flag-5'>注意力</b>機(jī)制

    PyTorch教程11.5多頭注意力

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程11.5多頭注意力.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:04 ?0次下載
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>11.5</b>之<b class='flag-5'>多頭</b><b class='flag-5'>注意力</b>

    PyTorch教程11.6之自注意力和位置編碼

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程11.6之自注意力和位置編碼.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:05 ?0次下載
    <b class='flag-5'>PyTorch</b>教程11.6之自<b class='flag-5'>注意力</b>和位置編碼

    PyTorch教程16.5之自然語言推理:使用注意力

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程16.5之自然語言推理:使用注意力.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 10:49 ?0次下載
    <b class='flag-5'>PyTorch</b>教程16.5之自然語言推理:使用<b class='flag-5'>注意力</b>

    PyTorch教程-11.6. 自注意力和位置編碼

    11.6. 自注意力和位置編碼? Colab [火炬]在 Colab 中打開筆記本 Colab [mxnet] Open the notebook in Colab Colab [jax
    的頭像 發(fā)表于 06-05 15:44 ?1095次閱讀
    <b class='flag-5'>PyTorch</b>教程-11.6. 自<b class='flag-5'>注意力</b>和位置編碼

    PyTorch教程-16.5。自然語言推理:使用注意力

    16.5。自然語言推理:使用注意力? Colab [火炬]在 Colab 中打開筆記本 Colab [mxnet] Open the notebook in Colab Colab
    的頭像 發(fā)表于 06-05 15:44 ?482次閱讀
    <b class='flag-5'>PyTorch</b>教程-16.5。自然語言推理:使用<b class='flag-5'>注意力</b>