0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

Faster Transformer v1.0源碼詳解

jf_pmFSk4VX ? 來源:后來遇見AI ? 2023-09-08 10:20 ? 次閱讀

寫在前面:本文將對 Nvidia BERT 推理解決方案 Faster Transformer 源碼進(jìn)行深度剖析,詳細(xì)分析作者的優(yōu)化意圖,并對源碼中的加速技巧進(jìn)行介紹,希望對讀者有所幫助。本文源碼解讀的內(nèi)容僅限 Faster Transformer v1.0 版本,更高版本的源碼將在后續(xù)文章中繼續(xù)解讀。

1 Faster Transformer

Transformer 模型最早在 2017 年由谷歌在論文中提出,拋棄了以往深度學(xué)習(xí)任務(wù)里面使用到的 CNN 和 RNN,取而代之的是一種 self-Attention 的結(jié)構(gòu),將 Attention 思想發(fā)揮到了極致,一定程度上解決了 RNN 的長序列信息丟失問題,基本取代了 RNN 的地位。
Transformer 一經(jīng)面世便在 NLP 領(lǐng)域大殺四方,近兩年一些文章開創(chuàng)性地將 Transformer 模型跨領(lǐng)域地引用到了 CV 任務(wù)中,并取得了不錯地成果。這也被許多學(xué)者認(rèn)為是開創(chuàng)了 CV 領(lǐng)域的新時代,甚至可能完全取代傳統(tǒng)的卷積操作。
雖然 Transformer在多種場景下都有優(yōu)秀的表現(xiàn),但是在推理部署階段,其計算性能卻受到了巨大的挑戰(zhàn):以 BERT 為原型的多層 Transformer 模型,其性能常常難以滿足在線業(yè)務(wù)對于低延遲和高吞吐的要求。以 BERT-BASE 為例,超過 90% 的計算時間消耗在 12 層 Transformer 的前向計算上。因此,一個高效的 Transformer 前向計算方案,既可以為在線業(yè)務(wù)帶來降本增效的作用,也有利于以 Transformer 結(jié)構(gòu)為核心的各類網(wǎng)絡(luò)在更多實際工業(yè)場景中落地。
基于上述背景,NVIDIA GPU 計算專家團(tuán)隊針對 Transformer 推理提出了的性能優(yōu)化方案:Faster Transformer。
Faster Transformer 是一個 BERT Transformer 單層前向計算的高效實現(xiàn),其代碼簡潔明了,后續(xù)可以通過簡單修改支持多種 Transformer 結(jié)構(gòu)。目前優(yōu)化集中在編碼器(encoder)的前向計算。底層由 CUDA 和 cuBLAS 實現(xiàn),支持 FP16 和 FP32 兩種計算模式,其中 FP16 可以充分利用 Volta 和 Turing 架構(gòu) GPU 上的 Tensor Core 計算單元。

2 優(yōu)化原理

在深入了解 Faster Transformer 的優(yōu)化原理之前,我們先來了解一下主流深度學(xué)習(xí)框架 Tensorflow 中 Transformer 的實現(xiàn)情況,僅僅以一個基本的激活函數(shù) gelu 為例,這個函數(shù)在框架中是通過 8 個類似 Pow、Add、和 Tanh 等基本 OP 來實現(xiàn)的。也就是說每進(jìn)行一次 gelu 運算要調(diào)用 8 次基本 OP,同時底層也對應(yīng) 8 次 GPU kernel 的調(diào)用,每一次調(diào)用也意味著一次顯存讀寫,先不說 kernel 計算耗時,光是顯存讀寫就已經(jīng)是大量的開銷。如何降低這部分開銷?最直觀的方法就是減少調(diào)用,讓數(shù)據(jù)一直留在顯存甚至寄存器里被訪問,即 OP 融合,一次調(diào)用就實現(xiàn)整個計算邏輯。
出于性能最大化的考慮,在 Faster Transformer 內(nèi)部,Nividia 將除矩陣乘法以外的所有 kernel 都進(jìn)行了盡可能的融合,單層 Transformer 的計算流程如下圖所示:
7d9045fa-4d92-11ee-a25d-92fbcf53809c.png

如圖所示,基于 OP 融合的思想,F(xiàn)aster Transformer 只用了 14 個 kernel 就完成了原來將近 60 個 kernel 的計算邏輯。這其中,8 個 kernel 是通過調(diào)用 cuBLAS 接口計算矩陣乘法(黃色框),其余 6 個是自定義 kernel(藍(lán)色框)。
接下來筆者將沿調(diào)用鏈逐步介紹每個 kernel 的優(yōu)化邏輯。

3 調(diào)用鏈

Faster Transformer v1.0 版本源碼地址如下,有興趣的讀者可以前往閱讀。

https://github.com/NVIDIA/FasterTransformer/tree/v1.0/fastertransformer

通讀源碼后筆者對調(diào)用關(guān)系梳理如下。

BertEncoderTransformer->forward()
    ->OpenMultiHeadAttention->forward()
        ->cublasGemmEx
        ->cublasGemmEx
        ->cublasGemmEx
        ->multiHeadAttr_nofuse_kernelLauncher
            ->add_QKV_bias  (kernel)
            ->cublasGemmStridedBatchedEx
            ->softmax_kernel    (kernel)
            ->cublasGemmStridedBatchedEx
            ->transpose (kernel)
    ->cublasGemmEx
    ->add_bias_input_layernorm_kernelLauncher   (kernel)
    ->cublasGemmEx
    ->add_bias_act_kernelLauncher   (kernel)
    ->cublasGemmEx
    ->add_bias_input_layernorm_kernelLauncher   (kernel)

從調(diào)用鏈也可以看出,總共 14 個步驟,與示意圖一一對應(yīng)。核心邏輯都在兩個類中實現(xiàn):BertEncoderTransformer 和 OpenMultiHeadAttention。

4 OpenMultiHeadAttention

OpenMultiHeadAttention 類中有兩個重要的成員方法:構(gòu)造函數(shù)、forward 方法。其中構(gòu)造函數(shù)內(nèi)主要進(jìn)行一些參數(shù)初始化功能,設(shè)備內(nèi)存的申請和初始化也在該函數(shù)內(nèi)進(jìn)行。forward 方法內(nèi)主要是多頭注意力機(jī)制核心邏輯的具體實現(xiàn)。

4.1 cublasGemmEx for Q、K、V

forward 方法中首先就是對輸入的 3 個 tensor 進(jìn)行線性變換,其實就是對 3 個 tensor 分別進(jìn)行 Dense 層變換,我們知道 Dense 是包含一個矩陣乘法和一個 add_bias 操作,這里只進(jìn)行矩陣乘法,add_bias 操作放在后面的 kernel 進(jìn)行。這里使用了 cuBLAS 接口計算矩陣乘法,具體代碼如下:

check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N, 
    n, m, k, 
    &alpha, 
    param_.attr_kernel_Q, AType_, n, 
    param_.from_tensor, BType_, k, 
    &beta, 
    query_buf_, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[0])));

check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N,
    n, m, k, 
    &alpha, 
    param_.attr_kernel_K, AType_, n, 
    param_.to_tensor, BType_, k, 
    &beta, 
    key_buf_, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[0])));

check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N, 
    n, m, k,
    &alpha,
    param_.attr_kernel_V, AType_, n, 
    param_.to_tensor, BType_, k, 
    &beta, 
    value_buf_, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[0])));

這里僅僅是矩陣乘法 API 的調(diào)用,按文檔傳參即可,這里不展開介紹,筆者計劃另開一篇文章專門介紹這個 API 的調(diào)用方法。

4.2 multiHeadAttr_nofuse_kernelLauncher

4.2.1 add_QKV_bias

上面說過 Dense 層包含矩陣乘法和 add_bias 操作,其中 add_bias 操作在核函數(shù) add_QKV_bias 中完成,源碼針對兩種數(shù)據(jù)類型 fp16 和 fp32 分別提供了一個 kernel,只是網(wǎng)絡(luò)結(jié)構(gòu)有所差異。
針對 fp32,每個 block 處理一個 word,總共有 batch_size * seq_len * 3 個 block,對于 Q、K、V 3 個 tensor 而言,前 batch_size * seq_len 個 block 處理 Q,中間 batch_size * seq_len 個 block 處理 K,后 batch_size * seq_len 個 block 處理 V。示意圖如下:
7dab6a92-4d92-11ee-a25d-92fbcf53809c.png

/**
 * @brief 
 * 
 * @tparam T                          OperationType
 * @param Q                           [batch_size, seq_len, head_num, size_per_head], query
 * @param bias_Q                      [head_num * size_per_head, ] length is the same as word's embedding dim
 * @param K                           [batch_size, seq_len, head_num, size_per_head], key
 * @param bias_K                      [head_num * size_per_head, ] length is the same as word's embedding dim
 * @param V                           [batch_size, seq_len, head_num, size_per_head], value
 * @param bias_V                      [head_num * size_per_head, ] length is the same as word's embedding dim
 * @param q_buf_                      [batch_size, head_num, seq_len, size_per_head], transpose query & add bias
 * @param k_buf_                      [batch_size, head_num, seq_len, size_per_head], transpose key & add bias
 * @param v_buf_                      [batch_size, head_num, seq_len, size_per_head], transpose value & add bias
 * @param batch_size
 * @param seq_len                     
 * @param head_num 
 * @param size_per_head 
 * @param word_per_block              1
 * @return __global__ 
 */
template
__global__
void add_QKV_bias(T* Q, const T* bias_Q, T* K, const T* bias_K, T* V, const T* bias_V, T* q_buf_, T* k_buf_, T* v_buf_, 
  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
{

  T* data_ptr;
  T* buf_ptr;
  const T* bias_ptr;
  // word counts per batch
  int m = batch_size * seq_len;
  // word embedding dim
  int n = head_num * size_per_head;
  // 總共有3m個block,第一部分處理q,第二部分處理k,第三部分處理v,這里使用qkv_id區(qū)分處理哪個矩陣
  int qkv_id = blockIdx.x * word_per_block / m;
  // 矩陣偏移量
  int row_offset = (blockIdx.x * word_per_block % m) * n;

  if(qkv_id == 0)
  {
    data_ptr = Q + row_offset;
    buf_ptr = q_buf_;
    bias_ptr = bias_Q;
  }
  else if(qkv_id == 1)
  {
    data_ptr = K + row_offset;
    buf_ptr = k_buf_;
    bias_ptr = bias_K;
  }
  else
  {
    data_ptr = V + row_offset;
    buf_ptr = v_buf_;
    bias_ptr = bias_V;
  }

  int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
  int head_id = threadIdx.x / size_per_head;
  int id_in_head = threadIdx.x % size_per_head;
  // word_id in seq, not data_index
  int word_start_id = (blockIdx.x * word_per_block) % seq_len;

  T bias = __ldg(&bias_ptr[threadIdx.x]);

  for(int i = word_start_id; i < word_start_id + word_per_block; ++i)
  {
    // add bias
    T tmp = data_ptr[threadIdx.x] + bias;
    // buf's shape: [bacth_size, head_num, seq_len, size_per_head]
    int target_id = batch_id * (seq_len * head_num * size_per_head) + head_id * seq_len * size_per_head + 
      i * size_per_head + id_in_head;

    buf_ptr[target_id] = tmp;
    data_ptr += n;
  }
}

核函數(shù)第一部分先根據(jù) block_id 確定當(dāng)前處理的 tensor 具體是 Q、K、V 中的哪一個,從而拿到輸入輸出變量的內(nèi)存地址。
第二部分求出 tensor 中對應(yīng)元素的索引,首先我們知道輸入輸出 tensor 是一個四維的 array,所以應(yīng)該有四個索引,按維度順序依次是 batch_id、word_start_id、head_id、id_in_head,有讀者看到這里可能會有疑問:為什么要計算這些索引,前面計算了矩陣偏移量 row_offset,完全可以在 block 內(nèi)按 thread_id 索引就可以拿到對應(yīng)元素。原因是在 add_QKV_bias 核函數(shù)中計算邏輯不僅僅是 add,還有 transpose,熟悉 multiHeadAttention 的讀者都知道對 Q、K、V 線性映射之后,緊接著就是一個 transpose 操作,目的是把 embedding_dim 這個維度劃分成多個獨立的 head,每個 head 后面單獨進(jìn)行 attention,所以要把 head 維度移到 seq_len 維度前面。換句話說這里的 transpose 解決的是“多頭”的問題,和 attention 無關(guān)。
理解了前面的邏輯,第三部分就比較簡單了,先進(jìn)行 add 操作,然后將結(jié)果按照 [bacth_size, head_num, seq_len, size_per_head] 的維度順序?qū)懺谳敵?tensor 中,這里隱含了一個 transpose,需要注意的是這個 transpose 操作是輸出的 tensor 中元素存儲順序相對于輸入 tensor 而言的,并不是對輸入 tensor 做了變換。

針對 fp16,每個 block 同時處理 Q、K、V 上的同一個 word,同一個線程先后處理 3 個 word 上對應(yīng)元素的計算邏輯,實際計算中把 half 都轉(zhuǎn)成了 half2,使用標(biāo)準(zhǔn)庫中的函數(shù) __hadd2 運算。網(wǎng)絡(luò)結(jié)構(gòu)如下:7dce1de4-4d92-11ee-a25d-92fbcf53809c.png

從圖中可以看出,block_size 為 embedding_dim 的一半,這是因為用了 half2 這個數(shù)據(jù)結(jié)構(gòu),實際上每個線程處理了 2 個元素,所以線程數(shù)量縮減一半。核函數(shù)內(nèi)部邏輯分為 2 個部分:1、求出 tensor 中對應(yīng)元素的索引。2、一次對 Q、K、V 進(jìn)行 add 和 transpose 操作。

template <>
__global__
void add_QKV_bias(__half* Q, const __half* bias_Q, __half* K, const __half* bias_K, __half* V, const __half* bias_V, 
  __half* q_buf_, __half* k_buf_, __half* v_buf_, 
  const int batch_size, const int seq_len, const int head_num, const int size_per_head, const int word_per_block)
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int batch_id = tid / (head_num * seq_len * size_per_head);
  int seq_id = (tid % (head_num * seq_len * size_per_head)) / (head_num * size_per_head);
  int head_id = (tid % (head_num * size_per_head)) / size_per_head;
  // id in head
  int id = tid % size_per_head;
  // from [batch_size, seq_len, head_num, size_per_head] tanspose to [bacth_size, head_num, seq_len, size_per_head]
  int target_id = target_index(batch_id, seq_id, head_id, id, batch_size, seq_len, head_num, size_per_head);

  int bias_id = threadIdx.x;

  half2* src_ptr = (half2*)Q;
  half2* dst_ptr = (half2*)q_buf_;
  const half2* bias_ptr = (const half2*)bias_Q;
  dst_ptr[target_id] = __hadd2(src_ptr[tid],  __ldg(&bias_ptr[bias_id]));

  src_ptr = (half2*)K;
  dst_ptr = (half2*)k_buf_;
  bias_ptr = (const half2*)bias_K;
  dst_ptr[target_id] = __hadd2(src_ptr[tid],  __ldg(&bias_ptr[bias_id]));

  src_ptr = (half2*)V;
  dst_ptr = (half2*)v_buf_;
  bias_ptr = (const half2*)bias_V;
  dst_ptr[target_id] = __hadd2(src_ptr[tid],  __ldg(&bias_ptr[bias_id]));
}

4.2.2 計算 attention scores

先來看一下 attention 的計算公式,定義如下:

其中,,也就是說這一步要解決的是一個矩陣計算,用 tensorflow 代碼表示如下:

scores = tf.matmul(query, key, transpose_b=True)

針對矩陣運算,源碼中直接調(diào)用了 cuBLAS API,具體代碼如下:

DataType_ alpha = (DataType_)1.0f, beta = (DataType_)0.0f;
// 計算 q * k^T
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    seq_len, seq_len, size_per_head,
    &alpha,
    k_buf_, AType_, size_per_head, seq_len * size_per_head,
    q_buf_, BType_, size_per_head, seq_len * size_per_head,
    &beta,
    qk_buf_, CType_, seq_len, seq_len * seq_len,
    batch_size * head_num,
    computeType_,
    static_cast(cublasAlgo_[1])));

不熟悉 attention 的讀者可能會問 attention scores 的具體含義是什么,筆者在早期的文章中有過介紹,其實就是兩個矩陣的詞向量兩兩相乘,向量相乘有什么含義?相似度,這個分?jǐn)?shù)就代表 Q、K 的相似度。有興趣的讀者可以移步筆者之前的文章詳細(xì)了解。(https://mp.weixin.qq.com/s/0zen3ItKmDLt5rTUbF37Mg)

4.2.3 softmax_kernel

拿到 Q、K 的相似度之后,直觀上只要右乘一個 V 就可以得到 attention out,其含義就是一個加權(quán)平均的概念,既然要加權(quán)平均,必然要對權(quán)值進(jìn)行歸一化處理,這里的 softmax 就是這個作用。關(guān)于 softmax 核函數(shù)的實現(xiàn)方法筆者在前兩篇文章也有介紹,OneFlow 官方給出了更為高效的實現(xiàn)方式,其高效的原因主要在訪存帶寬處理上,有興趣的讀者可以移步。【CUDA編程】OneFlow Softmax 算子源碼解讀之WarpSoftmax,【CUDA編程】OneFlow Softmax算子源碼解讀之BlockSoftmax

// 計算softmax(qk)
if(seq_len <= 32)
    block.x = 32;
else if(seq_len > 32 && seq_len <= 64)
    block.x = 64;
else if(seq_len > 64 && seq_len <= 128)
    block.x = 128;
else if(seq_len > 128 && seq_len <= 256)
    block.x = 256;
else if(seq_len > 256 && seq_len <= 512)
    block.x = 512;
else
    block.x = 1024;

if(batch_size * head_num <= 120)
{
    grid.x = batch_size * head_num * seq_len;
    softmax_kernel_v2<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler); 
}
else
{
    grid.x = batch_size * head_num;
    softmax_kernel<<>>(qk_buf_, attr_mask, batch_size, head_num, seq_len, scaler); 
}

源碼中核函數(shù)的 block_size 是根據(jù) seq_len 確定的,取大于 seq_len 且為 32 的 的最小值。
另外在調(diào)用 softmax kernel 之前,會根據(jù) batch_size * head_num 選擇不同的 softmax kernel,主要是為了保證在大 batch 的情況下的計算效率,這里以 120 為閾值,應(yīng)該是作者的經(jīng)驗數(shù)值。這里作者給出了 2 個 softmax kernel 的實現(xiàn)。
當(dāng) batch_size * head_num > 120 時,此時 batch 內(nèi)元素較多,grid_size 取 batch_size * head_num,這時一個線程內(nèi)處理一個 seq_len 的數(shù)據(jù)。
7df07434-4d92-11ee-a25d-92fbcf53809c.png

/**
 * @brief 
 * 
 * @tparam T 
 * @param qk_buf_                 [batch_size, head_num, seq_len, seq_len]
 * @param attr_mask               [batch_size, seq_len, seq_len]
 * @param batch_size 
 * @param head_num 
 * @param seq_len 
 * @param scaler                  縮放因子
 * @return __global__ 
 */
template 
__global__
void softmax_kernel(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, 
  const T scaler)
{
    // grid_size = batch_size * head_num
    int batch_id = blockIdx.x / head_num;
    // batch偏移量
    int qk_offset = blockIdx.x * seq_len * seq_len;
    int mask_offset = batch_id * seq_len * seq_len;

    __shared__ float s_sum, s_max;

    // 每次處理一個seq_len的數(shù)據(jù)
    for(int i = 0; i < seq_len; ++i)
    {
      float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
      float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
      // 對于某些padded的word,給一個很小的值使其近似達(dá)到不參與運算的目的
      mask_val = (1.0f - mask_val) * -10000.0f;

      float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val): -1e20f;

      float max_val = blockReduceMax(tmp);

      if(threadIdx.x == 0)
        s_max = max_val;
      __syncthreads();

      qk = threadIdx.x < seq_len ? __expf(tmp - s_max) : 0.0f;

      float sum_val = blockReduceSum(qk);

      if(threadIdx.x == 0)
      {
        s_sum = sum_val + 1e-6f;
      }
      __syncthreads();

      if(threadIdx.x < seq_len)
        qk_buf_[threadIdx.x + qk_offset] = (T)(qk / s_sum);

      qk_offset += seq_len;
      mask_offset += seq_len;
    }
}

核函數(shù)內(nèi)首先計算了每個元素偏移量,對于輸入 tensor 而言,每個 block 處理 seq_len * seq_len 個數(shù)據(jù),所以 block 內(nèi)元素偏移量為 blockIdx.x * seq_len * seq_len,而對于 mask 矩陣而言,其維度為 [batch_size, seq_len, seq_len],跟 head_num 無關(guān),所以其偏移量為 batch_id * seq_len * seq_len。
接下來是一層循環(huán),對于 seq_len * seq_len 矩陣而言,每個線程處理當(dāng)前 thread_id 列的元素,每輪循環(huán)結(jié)束,處理該列下一行的元素。在每一輪循環(huán)中,所有的線程一起處理一行數(shù)據(jù),首先拿到數(shù)據(jù) qk 以及 mask_val。如果 mask_val 為 0,則給 mask_val 賦一個很小的值最后加在 qk 上使 qk 值很小,以致最終這個 softmax 分量趨于 0;如果 mask_val 為 1,則 mask 不干預(yù)后續(xù)計算。每個線程拿到處理后的 qk 值即 tmp 后,進(jìn)行一次塊內(nèi)規(guī)約,即可求出當(dāng)前行的最大值 max_val,然后為了避免指數(shù)運算導(dǎo)致數(shù)值溢出,讓 tmp 減去 max_val 并求其指數(shù)值賦給 qk ,然后對 qk 再一次塊內(nèi)規(guī)約求出當(dāng)前行的和 s_sum,最后讓 qk 除以 s_sum 即可得到 softmax 值。核函數(shù)內(nèi)要注意在兩次塊內(nèi)規(guī)約后一定要進(jìn)行一次塊內(nèi)同步,否則可能計算錯誤。

當(dāng) batch_size * head_num <= 120 時,此時 batch 較小,grid_size 取 batch_size * head_num * seq_len,這時一個線程塊內(nèi)處理一行數(shù)據(jù),每個線程內(nèi)只處理一個的數(shù)據(jù)。

template 
__global__
void softmax_kernel_v2(T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, 
  const int seq_len, const float scaler)
{
    int batch_id = blockIdx.x / head_num / seq_len;
    int seq_id = blockIdx.x % seq_len;
    int qk_offset = blockIdx.x * seq_len;
    int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;

    __shared__ float s_sum, s_max;

    float qk = threadIdx.x < seq_len ? (float)qk_buf_[threadIdx.x + qk_offset] : 0.0f;
    float mask_val = threadIdx.x < seq_len ? (float)attr_mask[threadIdx.x + mask_offset] : 0.0f;
      
    mask_val = (1.0f - mask_val) * -10000.0f;

    float tmp = threadIdx.x < seq_len ? (float)(qk * (float)scaler + mask_val) : -1e20f;
    float max_val = blockReduceMax(tmp);
    if(threadIdx.x == 0)
      s_max = max_val;
    __syncthreads();

    float qk_tmp = threadIdx.x < seq_len ? __expf((float)(tmp - s_max)) : 0.0f;
    float sum_val = blockReduceSum(qk_tmp);

    if(threadIdx.x == 0)
    {
      s_sum = sum_val + 1e-6f;
    }
    __syncthreads();

    if(threadIdx.x < seq_len)
      qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}

這種情況下不涉及循環(huán)處理,計算邏輯與前面 softmax_kernel 循環(huán)體內(nèi)部計算邏輯相同,不再贅述。

4.2.4 計算多頭 attention out

這一步的意思就是使用 softmax 后的相似度矩陣右乘一個 V,得到多頭注意力輸出,注意這時候輸出 tensor 的維度為 [batch_size, head_num, seq_len, size_per_head]。源碼中直接調(diào)用了 cuBLAS API,具體代碼如下:

// 計算qk * v
check_cuda_error(cublasGemmStridedBatchedEx(cublas_handle,
    CUBLAS_OP_N, CUBLAS_OP_N,
    size_per_head, seq_len, seq_len,
    &alpha,
    v_buf_, AType_, size_per_head, seq_len * size_per_head,
    qk_buf_, BType_, seq_len, seq_len * seq_len,
    &beta,
    transpose_dst_, CType_, size_per_head, seq_len * size_per_head,
    batch_size * head_num,
    computeType_,
    static_cast(cublasAlgo_[2])));

4.2.5 transpose

前面說過,多頭 attention out 的維度為 [batch_size, head_num, seq_len, size_per_head],此時這些 head 已經(jīng)完成使命了,通過獨立的 head_num 組 attention 參數(shù)計算出了 attention out,最后需要做的就是把這 head_num 組 attention out 拼接起來,體現(xiàn)在 tensor 上就是做一次 transpose,將維度變?yōu)?[batch_size, seq_len, head_num, size_per_head]。源碼針對 fp16 和 fp32 分別提供了一個核函數(shù) transpose,計算邏輯和 add_QKV_bias 中 transpose 計算邏輯相同,索引按順序乘即可。具體代碼如下:

template
__global__
void transpose(T* src, T* dst, const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
  int batch_id = blockIdx.x / (head_num * seq_len);
  int seq_id = blockIdx.x % seq_len;
  int head_id = (blockIdx.x % (head_num * seq_len))/ seq_len;
  dst[batch_id * (head_num * seq_len * size_per_head) + seq_id * head_num * size_per_head
    + head_id * size_per_head + threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}

template<>
  __global__
void transpose(__half* src, __half* dst,
    const int batch_size, const int seq_len, const int head_num, const int size_per_head)
{
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  int batch_id = tid / (head_num * seq_len * size_per_head);
  int head_id = (tid % (head_num * seq_len * size_per_head)) / (seq_len * size_per_head);
  int seq_id = (tid % (seq_len * size_per_head)) / size_per_head;
  int id = tid % size_per_head;

  int target_id = target_index(batch_id, head_id, seq_id, id, batch_size, head_num, seq_len, size_per_head);
  half2* src_ptr = (half2*)src;
  half2* dst_ptr = (half2*)dst;

  dst_ptr[target_id] = src_ptr[tid];
}

5 BertEncoderTransformer

BertEncoderTransformer 類中有兩個重要的成員方法:構(gòu)造函數(shù)、forward 方法。其中構(gòu)造函數(shù)內(nèi)主要進(jìn)行一些參數(shù)初始化功能,設(shè)備內(nèi)存的申請和初始化也在該函數(shù)內(nèi)進(jìn)行。forward 方法內(nèi)主要是核心邏輯的實現(xiàn)。

5.1 attention forward

根據(jù)調(diào)用鏈可知,BertEncoderTransformer->forward() 中第一步就是 attention_->forward(),其中 attention_ 對象在構(gòu)造函數(shù)中被定義,attention_->forward() 執(zhí)行的就是第 4 節(jié)的內(nèi)容。

5.2 對 attention out 做線性變換

根據(jù)流程圖和調(diào)用鏈可知,這一步是對多頭注意力的輸出 tensor 做一層線性變換,右乘一個參數(shù)矩陣,其實就是一個不加激活函數(shù)的 Dense 層,分為矩陣乘法和 add bias 兩個操作步驟,這里調(diào)用了 cuBLAS API 實現(xiàn)矩陣乘法。

DataType_ alpha = (DataType_)1.0f;
DataType_ beta = (DataType_)0.0f;
int m = batch_size_ * from_seq_len_;
int k = head_num_ * size_per_head_;
int n = k;

check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N,
    n, m, k, 
    &alpha, 
    param_.attr_output_kernel, AType_, n, 
    attr_out_buf_, BType_, k, 
    &beta, 
    attr_matmul_buf_, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[0])));

5.3 add_bias_input_layernorm_kernel

從核函數(shù)名字可以看出,這個核函數(shù)實現(xiàn)了 3 個操作:add bias、add input、layernomalization。其中 add bias 是完成上一步線性變換未完成的加偏置工作,add input 是 transformer 模型中的殘差結(jié)構(gòu),layernomalization 則是層歸一化操作。綜合起來這個核函數(shù)的作用是:對線性變換后的 attention out 加偏置,然后加上原始輸入 tensor 組成一個殘差結(jié)構(gòu),最后進(jìn)行一次層歸一化變換。源碼中針對 fp16 和 fp32 分別提供了一個核函數(shù)實現(xiàn),計算邏輯都一樣,這里只以 fp32 為例介紹。7e1be790-4d92-11ee-a25d-92fbcf53809c.png

/**
 * @brief                       grid_size = m, block_size = n
 * 
 * @tparam T 
 * @param out                   [batch_size, sql_len, latent_dim]
 * @param input                 [batch_size, sql_len, latent_dim]
 * @param bias                  [latent_dim,]
 * @param gamma 
 * @param beta 
 * @param m                     batch_size * seq_len
 * @param n                     latent_dim
 * @return __global__ 
 */
template 
__global__ 
void add_bias_input_layernorm(T* out, const T* input, const T* bias, const T* gamma, const T* beta, int m, int n)
{
  int tid = threadIdx.x;

  __shared__ float s_mean;
  __shared__ float s_variance;
  float mean =  0.0f;
  float variance = 0.0f;

  float local_out = 0.0f;
  // add,一個block處理一行
  for(int i = tid; i < n; i += blockDim.x)
    local_out += (float)(out[blockIdx.x * n + i] + input[blockIdx.x * n + i] + __ldg(&bias[i]));
  // mean_i = sum(x_i[j] for j in range(k)) / k
  mean = blockReduceSum(local_out);
  if(threadIdx.x == 0)
    s_mean = mean / n;
  __syncthreads();
  // var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k + epsilon
  variance = blockReduceSum((local_out - s_mean) * (local_out - s_mean));
  if(threadIdx.x == 0)
    s_variance = variance / n + 1e-6f;
  __syncthreads();
  // x_i_normalized = (x_i - mean_i) / sqrt(var_i)
  // output_i = x_i_normalized * gamma + beta
  for(int i = tid; i < n; i += blockDim.x)
    out[blockIdx.x * n + i] = 
    (T)(((local_out - s_mean) * rsqrtf(s_variance)) * (float)(__ldg(&gamma[i])) + (float)(__ldg(&beta[i])));
}

如示意圖所示,核函數(shù)中每個 block 處理一行數(shù)據(jù),共 latent_dim = head_num * size_per_head 個元素,核函數(shù)中首先計算了 add bias、add input 兩個操作,并將計算結(jié)果存儲在寄存器變量 local_out 中。
接下來就是標(biāo)準(zhǔn)的 layerNormalization 操作,我們先來看一下 layerNormalization 的操作步驟,參照一下 tensorflow 框架 API 文檔。

For each sample x_i in inputs with k features, we compute the mean and variance of the sample:
mean_i = sum(x_i[j] for j in range(k)) / k
var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k

and then compute a normalized x_i_normalized, including a small factor epsilon for numerical stability.
x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)

And finally x_i_normalized is linearly transformed by gamma and beta, which are learned parameters:
output_i = x_i_normalized * gamma + beta
gamma and beta will span the axes of inputs specified in axis, and this part of the inputs' shape must be fully defined.

具體地,第一步計算均值和方差,核函數(shù)中使用塊內(nèi)規(guī)約計算出均值 s_mean 存儲在共享內(nèi)存中,所有塊內(nèi)線程都可以訪問。然后根據(jù) s_mean 和線程內(nèi)的 local_out 以及 epsilon 系數(shù)再進(jìn)行一次塊內(nèi)規(guī)約計算出方差 s_variance 存儲在共享內(nèi)存中。
第二步進(jìn)行歸一化和線性變換,對應(yīng) tensorflow API 的二、三步,直接計算即可,沒有其他技巧,公式如下:

5.4 FeedForward 結(jié)構(gòu)

根據(jù) Transformer 模型結(jié)構(gòu),多頭注意力之后為了增強表達(dá)能力,加了一個 FeedForward 層,該結(jié)構(gòu)內(nèi)部就是兩個 Dense 層,第一層 Dense 中使用了激活函數(shù),第二層沒有激活函數(shù)。所以 FeedForward 層中包含了 5 個操作:矩陣乘法、add bias、activation、矩陣乘法、add bias。

5.4.1 attention out * inter kernel

FeedForward 層第一次線性變換會擴(kuò)展 tensor 的最后一個維度的長度,源碼中將 latent_dim(也就是 n)擴(kuò)展為原來的 4 倍,所以這里的 inter kernel 的形狀為 [latent_dim, 4 * latent_dim],矩陣運算后的輸出 tensor 形狀為 [batch_size, seq_len, 4 * latent_dim]。

n *= 4;
check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N,
    n, m, k, 
    &alpha, 
    param_.inter_kernel, AType_, n, 
    attr_matmul_buf_, BType_, k, 
    &beta, 
    inter_matmul_buf_, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[1])));

5.4.2 add_bias_act_kernel

顧名思義,add_bias_act_kernel 核函數(shù)包含了 add bias 和 activation 兩個操作。源碼中 block_size = n / 4 實際就是 latent_dim,為什么不直接取上一次運算后的矩陣寬度 n = 4 * latent_dim 呢?這里是希望一行元素(4 * latent_dim)能在一個 block 內(nèi)處理,如果 block_size 直接取 n = 4 * latent_dim,可能超過 1024,因此還取 latent_dim,線程內(nèi)循環(huán) 4 次處理即可。同樣,源碼中針對 grid_size 也取了 m / 4,在線程中通過循環(huán)每次跨 m / 4 步長處理 4 行數(shù)據(jù)。

template 
__global__ 
void add_bias_act(T* out, const T* bias, int m, int n)
{
  T val, reg_bias;

  int row_id = blockIdx.x;
  int ite = n / blockDim.x;
  int tid = threadIdx.x;

  for(int i = 0; i < ite; ++i)
  {
    reg_bias = __ldg(&bias[i * blockDim.x + tid]);
    row_id = blockIdx.x;

    while(row_id < m){
      val = out[tid + i * blockDim.x + row_id * n]+ reg_bias;
      out[tid + i * blockDim.x + row_id * n] = gelu(val);
      row_id += gridDim.x;
    }
  }
}

核函數(shù)中先對列進(jìn)行循環(huán),ite = 4,從全局內(nèi)存讀出當(dāng)前列的 bias,然后針對行進(jìn)行循環(huán),步長為 m / 4,循環(huán)體內(nèi)部對當(dāng)前行當(dāng)前列的元素進(jìn)行 add bias 和 gelu 操作,這里gelu 操作是一個簡單的 element-wise 操作,比較簡單不再介紹。
筆者點評:這里筆者私以為沒有必要 grid_size 也取 m / 4,cuda 本身對線程塊的數(shù)量沒有限制,完全可以直接取 m,每次每個線程只處理一行數(shù)據(jù),一方面可以增加并行程度,另一方面代碼可閱讀性也更好。筆者給出代碼如下,親測可用。

dim3 grid(m);
dim3 block(n / 4);
assert(block.x <= 1024);
add_bias_act_v2<<>>(out, bias, m, n);

template 
__global__ 
void add_bias_act_v2(T* out, const T* bias, int m, int n) {
  T val, reg_bias;

  int row_id = blockIdx.x;
  int ite = n / blockDim.x;
  int tid = threadIdx.x;

  for(int i = 0; i < ite; ++i) {
    reg_bias = __ldg(&bias[i * blockDim.x + tid]);
    val = out[tid + i * blockDim.x + row_id * n]+ reg_bias;
    out[tid + i * blockDim.x + row_id * n] = gelu(val);
  }
}

5.4.3 inter out * out kernel

FeedForward 層第二次線性變換將 tensor 的最后一個維度的長度轉(zhuǎn)換為原始大小,源碼中將 n 重新賦值為 latent_dim,所以這里的 out kernel 的形狀為 [4 * latent_dim, latent_dim],矩陣運算后的輸出 tensor 形狀為 [batch_size, seq_len, latent_dim]。

n = k;
k *= 4;
check_cuda_error(cublasGemmEx(param_.cublas_handle, 
    CUBLAS_OP_N, CUBLAS_OP_N,
    n, m, k, 
    &alpha, 
    param_.output_kernel, AType_, n, 
    inter_matmul_buf_, BType_, k, 
    &beta, 
    param_.transformer_out, CType_, n, 
    computeType_, 
    static_cast(cublasAlgo_[2])));

5.5 add_bias_input_layernorm_kernel

這個核函數(shù)的計算邏輯在 5.3 中已經(jīng)介紹過了,包含加偏置項、殘差結(jié)構(gòu)、層歸一化三個操作,不再贅述。

6 小結(jié)

至此,Transformer encoder 前向計算的 14 個操作優(yōu)化技巧已介紹完畢??偨Y(jié)如下:

針對半精度 fp16 的優(yōu)化方面。首先,在 kernel 的實現(xiàn)中,將輸入的 half 指針轉(zhuǎn)成 half2 類型,并使用了 half2 相關(guān)的數(shù)學(xué)函數(shù)。這樣不僅僅可以達(dá)到 2 倍于 half 的訪存帶寬和計算吞吐,還可以極大地減少指令的發(fā)射數(shù)量。其次,在 softmax 以及 layerNormalization 的操作中,為防止求和溢出,將數(shù)據(jù)以 half2 的形式讀入后,會轉(zhuǎn)成 float2 類型,來做求和計算,這里就非常細(xì)致了,盡可能地保障了較高精度,值得學(xué)習(xí)借鑒。

針對訪存帶寬方面,筆者以為除 fp16 以外其它數(shù)據(jù)類型也可以進(jìn)一步優(yōu)化,比如可以自定義 pack 類型進(jìn)行合并讀寫,盡量把帶寬打滿。

針對線程網(wǎng)絡(luò)結(jié)構(gòu)方面,源碼中基本使用一個 block 處理一行數(shù)據(jù)的模式進(jìn)行設(shè)計,這里筆者私以為針對 seq_len 和 latent_dim 已知比較小的情況下(不超過1024),完全可以一個線程束處理一行數(shù)據(jù),束內(nèi)同步的開銷遠(yuǎn)小于塊內(nèi)同步。當(dāng)然,這個要求確實有些苛刻了。

源碼中提供了一個塊內(nèi)規(guī)約的代碼,思路非常好,值得初學(xué) cuda 的讀者品讀。

審核編輯:湯梓紅

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • NVIDIA
    +關(guān)注

    關(guān)注

    14

    文章

    4793

    瀏覽量

    102421
  • gpu
    gpu
    +關(guān)注

    關(guān)注

    27

    文章

    4590

    瀏覽量

    128132
  • 源碼
    +關(guān)注

    關(guān)注

    8

    文章

    626

    瀏覽量

    28965
  • Transformer
    +關(guān)注

    關(guān)注

    0

    文章

    135

    瀏覽量

    5943

原文標(biāo)題:【CUDA編程】Faster Transformer v1.0 源碼詳解

文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    Faster Transformer v2.1版本源碼解讀

    和優(yōu)化技巧進(jìn)行了深度剖析,有興趣的讀者可以移步——【CUDA編程】Faster Transformer v1.0 源碼詳解。 在 Faste
    的頭像 發(fā)表于 09-19 11:39 ?1163次閱讀
    <b class='flag-5'>Faster</b> <b class='flag-5'>Transformer</b> <b class='flag-5'>v</b>2.1版本<b class='flag-5'>源碼</b>解讀

    Altera viterbi compiler v1.0

    Altera viterbi compiler v1.0 下載
    發(fā)表于 03-23 09:45 ?0次下載

    LeMedia使用教程V1.0

    LeMedia使用教程V1.0,介紹LeMedia如何使用。
    發(fā)表于 02-22 17:29 ?10次下載

    ZYBO入門指導(dǎo)手冊(一)v1.0——Vivado

    ZYBO入門指導(dǎo)手冊(一)v1.0——Vivado
    發(fā)表于 09-27 17:02 ?21次下載

    FPGA II實戰(zhàn)演練V1.0

    FPGA II實戰(zhàn)演練V1.0,感興趣的小伙伴們可以瞧一瞧。
    發(fā)表于 11-17 11:43 ?7次下載

    PCiRCTRL使用說明V1.0

    PCiRCTRL使用說明V1.0,紅外編碼解碼YiRTX02,YiRTX03芯片調(diào)試工具。
    發(fā)表于 01-11 12:38 ?9次下載

    串口調(diào)試VB源代碼V1.0

    串口調(diào)試VB源代碼V1.0
    發(fā)表于 02-07 21:06 ?19次下載

    CPP技術(shù)白皮書V1.0

    CPP技術(shù)白皮書V1.0
    發(fā)表于 09-05 14:36 ?12次下載
    CPP技術(shù)白皮書<b class='flag-5'>V1.0</b>

    KT803C_數(shù)據(jù)手冊_V1.0

    KT803C_數(shù)據(jù)手冊_V1.0
    發(fā)表于 12-04 13:55 ?8次下載

    CC2530底板V1.0(含原路圖)

    CC2530底板V1.0(含原路圖)
    發(fā)表于 01-17 10:25 ?0次下載

    V1.0 ATT7021應(yīng)用說明

    V1.0 ATT7021應(yīng)用說明
    發(fā)表于 06-10 15:36 ?3次下載
    <b class='flag-5'>V1.0</b> ATT7021應(yīng)用說明

    TI MCU SW ICDI DRIVERS v1.0

    TI MCU SW-ICDI-DRIVERS v1.0
    發(fā)表于 10-08 09:36 ?8次下載

    v1.0開發(fā)板資料

    電子發(fā)燒友網(wǎng)站提供《v1.0開發(fā)板資料.zip》資料免費下載
    發(fā)表于 10-09 15:20 ?5次下載
    <b class='flag-5'>v1.0</b>開發(fā)板資料

    OK3588-C_硬件手冊_V1.0

    OK3588-C_硬件手冊_V1.0
    發(fā)表于 12-03 11:55 ?36次下載

    PCB設(shè)計工藝指導(dǎo)手冊(v1.0).zip

    PCB設(shè)計工藝指導(dǎo)手冊(v1.0)
    發(fā)表于 12-30 09:20 ?8次下載