0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線(xiàn)課程
  • 觀(guān)看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

OneFlow Softmax算子源碼解讀之WarpSoftmax

jf_pmFSk4VX ? 來(lái)源:后來(lái)遇見(jiàn)AI ? 作者:后來(lái)遇見(jiàn)AI ? 2024-01-08 09:24 ? 次閱讀

寫(xiě)在前面:近來(lái)筆者偶然間接觸了一個(gè)深度學(xué)習(xí)框架 OneFlow,所以這段時(shí)間主要在閱讀 OneFlow 框架的 cuda 源碼。官方源碼基于不同場(chǎng)景分三種方式實(shí)現(xiàn) Softmax,本文主要介紹其中一種的實(shí)現(xiàn)過(guò)程,即 Warp 級(jí)別 Softmax,適用于矩陣寬度不超過(guò) 1024 的情況。

1 Softmax

Softmax 操作是深度學(xué)習(xí)模型中最常用的操作之一。在深度學(xué)習(xí)的多分類(lèi)任務(wù)中,最后一層通常是一個(gè) Softmax 操作將 logits 映射成概率,然后結(jié)合交叉熵求損失。另外還有一些場(chǎng)景會(huì)用到 Softmax 做一個(gè)歸一化操作,比如 Transformer 結(jié)構(gòu)中 query 和 key 矩陣相乘并縮放后會(huì)執(zhí)行一個(gè) Softmax 操作,這一步的意義是求出 query 和 key 中每一項(xiàng)的兩兩相似度,具體筆者在另一篇文章有詳述——【ASR】基于DFCNN-CTC模型的語(yǔ)音識(shí)別系統(tǒng)(二)

59c99cac-ad3c-11ee-8b88-92fbcf53809c.png

圖1 Scaled Dot-Product Attention 結(jié)構(gòu)示意圖

深度學(xué)習(xí)框架中的所有算子底層都對(duì)應(yīng)著 GPU上的 CUDA kernel function,Softmax 操作也不例外。Softmax 作為一個(gè)被廣泛使用的算子,其 CUDA Kernel 的實(shí)現(xiàn)會(huì)影響很多網(wǎng)絡(luò)最終的訓(xùn)練速度。那么如何實(shí)現(xiàn)一個(gè)高效的 Softmax CUDA Kernel?本文將會(huì)介紹 OneFlow 中優(yōu)化的 Softmax CUDA Kernel 的技巧,在這之前我們先來(lái)看一下 Softmax 的計(jì)算公式。
定義 x 是一個(gè) n 維向量,其 Softmax 輸出 y 也是一個(gè) n 維向量,那么有如下計(jì)算公式:

從上面的公式可以發(fā)現(xiàn)一個(gè)問(wèn)題,當(dāng) 為一個(gè)較大的正數(shù)時(shí),取指數(shù)后 將會(huì)非常大,從而導(dǎo)致數(shù)值溢出,如何解決這個(gè)問(wèn)題呢?
一般的處理方法是,讓每個(gè)分量去減掉向量的最大值,這樣可以保證取指數(shù)后的結(jié)果必然在 0~1 之間,可以有效避免數(shù)值溢出。處理后的公式如下:

根據(jù)公式可以看出,要執(zhí)行 Softmax 計(jì)算,需要實(shí)現(xiàn) 5 個(gè)業(yè)務(wù)邏輯:reduceMax、broadcastSub、exp、reduceSum、broadcastDiv。下面筆者將對(duì)源碼中的計(jì)算技巧進(jìn)行解讀,有興趣的讀者可以下載源碼來(lái)閱讀(https://github.com/BBuf/how-to-optim-algorithm-in-cuda/blob/master/softmax/oneflow_softmax.cu)。

2 三種實(shí)現(xiàn)方式

Softmax 函數(shù)的輸入形狀為:(num_rows, num_cols),num_cols 的變化會(huì)對(duì)有效帶寬產(chǎn)生影響。因?yàn)椋瑳](méi)有一種通用的優(yōu)化方法可以實(shí)現(xiàn)在所有 num_cols 的情況下都是傳輸最優(yōu)的。所以,在 OneFlow 中采用分段函數(shù)優(yōu)化 SoftmaxKernel:對(duì)于不同 num_cols 范圍,選擇不同的實(shí)現(xiàn),以期在所有情況下都能達(dá)到較高的有效帶寬。
針對(duì)不同的 Softmax 場(chǎng)景,OneFlow 提供了三種實(shí)現(xiàn),分段對(duì) Softmax kernel 進(jìn)行優(yōu)化:

(1) 一個(gè) Warp 處理一行的計(jì)算,適用于 num_cols <= 1024 情況

(2) 一個(gè) Block 處理一行的計(jì)算,借助 Shared Memory 保存中間結(jié)果數(shù)據(jù),適用于需要的 Shared Memory 資源滿(mǎn)足 Kernel Launch 的可啟動(dòng)條件的情況,在本測(cè)試環(huán)境中是 1024 < num_cols <= 4096。

(3) 一個(gè) Block 處理一行的計(jì)算,不使用 Shared Memory,重復(fù)讀輸入 x,適用于不支持(1)、(2)的情況。

分段處理邏輯在 DispatchSoftmax 函數(shù)中體現(xiàn),主體代碼如下:

if (cols < 1024) {
  return DispatchSoftmaxWarpImpl(
      stream, load, store, rows, cols);
} else {
  bool dispatch_smem_impl_success;
  {
    cudaError_t err =
        TryDispatchSoftmaxBlockSMemImpl(
            stream, load, store, rows, cols, &dispatch_smem_impl_success);
    if (err != cudaSuccess) { return err; }
  }
  if (!dispatch_smem_impl_success) {
    return DispatchSoftmaxBlockUncachedImpl(
        stream, load, store, rows, cols);
  }
  return cudaSuccess;
}
,>,>,>

3 WarpSoftmax

3.1 數(shù)據(jù) Pack 提升訪(fǎng)問(wèn)帶寬

在筆者上一篇文章【CUDA編程】OneFlow Element-Wise 算子源碼解讀中詳細(xì)地介紹了如何進(jìn)行向量化讀寫(xiě),有興趣的讀者可以移步,這里我們先看源碼。

template
struct GetPackType {
  using type = typename std::aligned_storage::type;
};

template
using PackType = typename GetPackType::type;

template
union Pack {
  static_assert(sizeof(PackType) == sizeof(T) * N, "");
  __device__  Pack() {
    // do nothing
  }
  PackType storage;
  T elem[N];
};
,>,>,>

oneflow 利用 union 共享空間的特性實(shí)現(xiàn)了一個(gè) Pack 類(lèi)型,細(xì)心的讀者可能會(huì)發(fā)現(xiàn),跟 elementwise.cuh 源碼相比,這里少了一個(gè) Packed 類(lèi),這是因?yàn)?elementwise.cuh 實(shí)現(xiàn)的時(shí)間晚于 softmax.cuh。可能考慮到 Pack 后類(lèi)型的內(nèi)存對(duì)齊特性,重新定義了 Packed 類(lèi),并聲明了內(nèi)存對(duì)齊值為 pack_size * sizeof(T)。
接下來(lái)定義了兩個(gè)代表輸入和輸出的數(shù)據(jù)結(jié)構(gòu) DirectLoad 和 DirectStore,分別實(shí)現(xiàn)了 load 和 store 兩個(gè)函數(shù)用來(lái)把讀取和寫(xiě)入一個(gè) pack 的數(shù)據(jù)。使用 DirectLoad 和 DirectStore 有兩個(gè)好處:

可以在CUDA Kernel中只關(guān)心計(jì)算類(lèi)型ComputeType,而不用關(guān)心具體的數(shù)據(jù)類(lèi)型T。

只需要加幾行代碼就可以快速支持 Softmax 和其他 Kernel Fuse,減少帶寬需求,提升整體性能。

/**
 * @brief 定義了輸入的數(shù)據(jù)結(jié)構(gòu)
 * 
 * @tparam SRC 輸入數(shù)據(jù)的類(lèi)型
 * @tparam DST 計(jì)算數(shù)據(jù)的類(lèi)型,ComputeType
 */
template
struct DirectLoad {
  /**
   * @brief Construct a new Direct Load object
   * 
   * @param src 輸入的數(shù)據(jù)源
   * @param row_size num of elements per row
   */
  DirectLoad(const SRC* src, int64_t row_size) : src(src), row_size(row_size) {}
  /**
   * @brief 從數(shù)據(jù)源 load 一個(gè) pack 數(shù)據(jù)到 dst 
   * 
   * @tparam N pack_size
   * @param dst 
   * @param row 數(shù)據(jù)源的第 row 行
   * @param col 數(shù)據(jù)源的第 col 列
   * @return __device__ 
   */
  template
  __device__ void load(DST* dst, int64_t row, int64_t col) const {
    Pack pack;
    const int64_t offset = (row * row_size + col) / N;  // pack 偏移量
    pack.storage = *(reinterpret_cast*>(src) + offset);
    #pragma unroll
    for (int i = 0; i < N; ++i) { dst[i] = static_cast(pack.elem[i]); }
  }
  const SRC* src;
  int64_t row_size;
};

template
struct DirectStore {
  DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {}
  template
  __device__ void store(const SRC* src, int64_t row, int64_t col) {
    Pack pack;
    const int64_t offset = (row * row_size + col) / N;
#pragma unroll
    for (int i = 0; i < N; ++i) { pack.elem[i] = static_cast(src[i]); }
    *(reinterpret_cast*>(dst) + offset) = pack.storage;
  }
  DST* dst;
  int64_t row_size;
};
,>,>

3.2 調(diào)用鏈

針對(duì) WarpSoftmax 這個(gè)分支,對(duì)源碼中函數(shù)的調(diào)用關(guān)系梳理后如下:

DispatchSoftmaxWarpImpl
  ->DispatchSoftmaxWarpImplPackSize
    ->DispatchSoftmaxWarpImplCols
      ->DispatchSoftmaxWarpImplPadding
        ->LaunchSoftmaxWarpImpl
          ->SoftmaxWarpImpl(kernel)

接下來(lái)將從上到下逐個(gè)解讀其實(shí)現(xiàn)細(xì)節(jié)。

3.3 DispatchSoftmaxWarpImpl

該函數(shù)被 DispatchSoftmax 函數(shù)調(diào)用,其內(nèi)部邏輯非常簡(jiǎn)單,實(shí)例化了一個(gè) DispatchSoftmaxWarpImplPackSize 類(lèi)并調(diào)用了其重載的()函數(shù),所有的參數(shù)都是透?jìng)鳎瑳](méi)有其他邏輯。

template
inline cudaError_t DispatchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  return DispatchSoftmaxWarpImplPackSize()(stream, load, store, rows, cols);
}
,>

3.4 DispatchSoftmaxWarpImplPackSize

顧名思義,pack_size 參數(shù)是在這個(gè)結(jié)構(gòu)體內(nèi)部確定的。該結(jié)構(gòu)體內(nèi)部重載了一個(gè)小括號(hào)運(yùn)算符,其函數(shù)內(nèi)部只做了一件事,對(duì)矩陣的列數(shù)進(jìn)行判斷,如果是偶數(shù),pack_size 取 2,否則取 1。

template
struct DispatchSoftmaxWarpImplPackSize {
  cudaError_t operator()(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
    if (cols % 2 == 0) {
      return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols);
    } else {
      return DispatchSoftmaxWarpImplCols(stream, load, store, rows, cols);
    }
  }
};
,>,>

筆者讀到這里不禁產(chǎn)生了疑問(wèn),前面說(shuō)過(guò)數(shù)據(jù) Pack 后可以提升 GPU 訪(fǎng)問(wèn)帶寬,但是在該函數(shù)中 pack_size 最大也只能取到 2,在前面的文章中筆者提到過(guò)在 cuda 中最大支持一次 128 bit的讀寫(xiě),意味著針對(duì) float 類(lèi)型 pack_size 最大可以取 4,對(duì) half 類(lèi)型甚至可以取 8。所以帶著這個(gè)疑問(wèn)筆者咨詢(xún)了官方源碼的作者俊丞大佬,答曰可以取更大的 pack_size,這里是考慮到更多的特化會(huì)導(dǎo)致編譯時(shí)間過(guò)長(zhǎng)所以只實(shí)現(xiàn)了 2 個(gè)模板。獲得解答后,筆者自行實(shí)現(xiàn)了一個(gè) pack_size = 4 的模板,然后經(jīng)過(guò)實(shí)測(cè)(矩陣大小為 1024*1024, 32*16)發(fā)現(xiàn), pack_size 取 4 和取 2 相比幾乎沒(méi)有提升。。。倒是取 2 相比取 1 有 6% 的提升。猜測(cè)可能是 pack_size 影響了 DispatchSoftmaxWarpImplCols 這個(gè) kernel 的啟動(dòng)參數(shù),所以間接影響了性能,這里官方肯定做過(guò)一系列測(cè)試。。。

3.5 DispatchSoftmaxWarpImplCols

DispatchSoftmaxWarpImplCols 函數(shù)代碼比較長(zhǎng),讀起來(lái)稍顯晦澀,要理解它的實(shí)現(xiàn)邏輯,我們可以換個(gè)思路,看它的目的是什么,然后倒推它的實(shí)現(xiàn)過(guò)程。很顯然,該函數(shù)在最后調(diào)用了 DispatchSoftmaxWarpImplPadding 函數(shù),那么我們就來(lái)看被調(diào)用的函數(shù)需要哪些參數(shù),DispatchSoftmaxWarpImplCols 的作用就是確定這些參數(shù)。讀了 DispatchSoftmaxWarpImplPadding 的參數(shù)列表我們可以發(fā)現(xiàn),有三個(gè)重要參數(shù)需要傳入:cols_per_thread, thread_group_width, rows_per_access,這里先對(duì)這三個(gè)參數(shù)做一個(gè)解釋?zhuān)?/p>

cols_per_thread:每個(gè)線(xiàn)程處理的元素列數(shù)

thread_group_width:線(xiàn)程組的大小,一個(gè)線(xiàn)程組要處理整行的數(shù)據(jù)

rows_per_access:每個(gè)線(xiàn)程組一次處理的行數(shù)

函數(shù)體內(nèi)主要是針對(duì) cols 的大小做了分支,前后代碼有一個(gè)分水嶺,即 cols <= 32 * pack_size,可以分開(kāi)來(lái)看。
當(dāng) cols <= 32 * pack_size 時(shí),thread_group_width 取 2 的 n 次冪,從 1 到 32 一直判斷,如果 cols <= (thread_group_width)*pack_size 那么 thread_group_width 就取當(dāng)前的值。cols_per_thread 取 pack_size,就是說(shuō)當(dāng)前一個(gè)線(xiàn)程只處理一個(gè) Pack 寬度的數(shù)據(jù),這時(shí)候數(shù)據(jù)量也比較小,所以對(duì) rows 也做了一層判斷,如果 rows 是偶數(shù),那么 rows_per_access 取 2,每個(gè)線(xiàn)程一次處理 2 行數(shù)據(jù),否則一次只處理 1 行。
當(dāng) cols > 32 * pack_size 時(shí),這種屬于數(shù)據(jù)量比較大的情況,所以 thread_group_width 直接取能取到的最大值 32,即 Warp 的大小。每個(gè)線(xiàn)程也要處理多個(gè) Pack,cols_per_thread 取 pack_size 的整數(shù)倍,直到 32 * cols_per_thread = 1024,一直判斷 cols <= 32 * cols_per_thread,如果滿(mǎn)足條件,cols_per_thread 就取當(dāng)前值。對(duì)于 rows_per_access 參數(shù),直接取 1,即每個(gè)線(xiàn)程一次只處理 1 行數(shù)據(jù)。
至此函數(shù)邏輯就介紹完了,這個(gè)函數(shù)里有兩個(gè)宏,不熟悉 C++ 的讀者讀起來(lái)可能沒(méi)那么順暢,這里推薦一個(gè)網(wǎng)站(https://cppinsights.io/),從編譯器的角度將 C++ 源碼展開(kāi)顯示,對(duì)閱讀泛型編程和宏這類(lèi)代碼很有幫助。

template
typename std::enable_if::type DispatchSoftmaxWarpImplCols(
    cudaStream_t stream, LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  if (cols <= 0) { return cudaErrorInvalidValue; }
#define DEFINE_ONE_ELIF(thread_group_width)                                                        
  else if (cols <= (thread_group_width)*pack_size) {                                               
    if (rows % 2 == 0) {                                                                           
      return DispatchSoftmaxWarpImplPadding(stream, load, store, 
                                                                              rows, cols);         
    } else {                                                                                       
      return DispatchSoftmaxWarpImplPadding(stream, load, store, 
                                                                              rows, cols);         
    }                                                                                              
  }
  DEFINE_ONE_ELIF(1)
  DEFINE_ONE_ELIF(2)
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(col)                                                                      
  else if (cols <= (col)*kWarpSize) {                                                             
    return DispatchSoftmaxWarpImplPadding(stream, load, store, rows, cols);            
  }
  DEFINE_ONE_ELIF(4)
  DEFINE_ONE_ELIF(6)
  DEFINE_ONE_ELIF(8)
  DEFINE_ONE_ELIF(10)
  DEFINE_ONE_ELIF(12)
  DEFINE_ONE_ELIF(14)
  DEFINE_ONE_ELIF(16)
  DEFINE_ONE_ELIF(18)
  DEFINE_ONE_ELIF(20)
  DEFINE_ONE_ELIF(22)
  DEFINE_ONE_ELIF(24)
  DEFINE_ONE_ELIF(26)
  DEFINE_ONE_ELIF(28)
  DEFINE_ONE_ELIF(30)
  DEFINE_ONE_ELIF(32)
#undef DEFINE_ONE_ELIF
  else {
    return cudaErrorInvalidValue;
  }
}
,>,>,>

3.6 DispatchSoftmaxWarpImplPadding

顧名思義,這個(gè)函數(shù)內(nèi)部的邏輯跟 padding 相關(guān),實(shí)際上這個(gè)函數(shù)只做了一件事,當(dāng) cols == cols_per_thread * thread_group_width 時(shí)說(shuō)明矩陣列數(shù)能被線(xiàn)程組均分,這時(shí)候不需要 padding,否則需要 padding。

template
inline cudaError_t DispatchSoftmaxWarpImplPadding(cudaStream_t stream, LOAD load, STORE store,
                                                  const int64_t rows, const int64_t cols) {
  if (cols == cols_per_thread * thread_group_width) {
    return LaunchSoftmaxWarpImpl(
        stream, load, store, rows, cols);
  } else {
    return LaunchSoftmaxWarpImpl(
        stream, load, store, rows, cols);
  }
}
,>,>

3.7 LaunchSoftmaxWarpImpl

該函數(shù)是核函數(shù)的啟動(dòng)函數(shù),函數(shù)內(nèi)主要是確定 block_size、num_blocks 這兩個(gè)參數(shù)。這兩個(gè)參數(shù)的確定筆者在上一篇文章【CUDA編程】OneFlow Element-Wise 算子源碼解讀中有詳細(xì)介紹,有興趣的讀者可以移步,這里不再贅述。
函數(shù)中定義了一個(gè) block_dim 對(duì)象,從初始化參數(shù)可以看出這是一個(gè)二維的 block,寬是 thread_group_width,高取 thread_groups_per_block。從核函數(shù)啟動(dòng)參數(shù) grid_dim_x 可以看出網(wǎng)格是一維的,由此我們可以確定 cuda 線(xiàn)程網(wǎng)格的形狀。這里筆者給出示意圖如下。

59dfac90-ad3c-11ee-8b88-92fbcf53809c.png

圖2 線(xiàn)程網(wǎng)格示意圖

template
inline cudaError_t LaunchSoftmaxWarpImpl(cudaStream_t stream, LOAD load, STORE store,
                                         const int64_t rows, const int64_t cols) {
  constexpr int block_size = 128;
  constexpr int waves = 32;
  static_assert(block_size % thread_group_width == 0, "");
  constexpr int thread_groups_per_block = block_size / thread_group_width;
  dim3 block_dim(thread_group_width, thread_groups_per_block);
  const int64_t num_blocks =
      (rows / rows_per_access + thread_groups_per_block - 1) / thread_groups_per_block;
  int grid_dim_x;
  {
    cudaError_t err = GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x);
    if (err != cudaSuccess) { return err; }
  }
  SoftmaxWarpImpl
      <<>>(load, store, rows, cols);
  return cudaPeekAtLastError();
}
,>,>

3.8 核函數(shù) SoftmaxWarpImpl

接下來(lái)就是 WarpSoftmax 的核函數(shù) SoftmaxWarpImpl,該函數(shù)體內(nèi)部實(shí)現(xiàn)了 Softmax 的核心計(jì)算邏輯。

template
__global__ void SoftmaxWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols) {
  static_assert(cols_per_thread % pack_size == 0, "");  // 確保每個(gè)thread處理的元素個(gè)數(shù)正好被完全pack
  static_assert(thread_group_width <= kWarpSize, "");   // 處理元素的線(xiàn)程組的寬度需要小于等于kWarpSize,并且需要被kWarpSize整除
  static_assert(kWarpSize % thread_group_width == 0, "");
  constexpr int num_packs = cols_per_thread / pack_size;  // 每個(gè)線(xiàn)程處理的 pack 的數(shù)目,即每個(gè)線(xiàn)程需要處理的元素個(gè)數(shù) / pack_size
  assert(cols <= cols_per_thread * thread_group_width);   // 確保一個(gè)thread group 能處理的列數(shù)大于等于一行
  ComputeType buf[rows_per_access][cols_per_thread];  // 聲明寄存器大小,這是一個(gè)二維數(shù)組
  const int global_thread_group_id = blockIdx.x * blockDim.y + threadIdx.y;   // 當(dāng)前warp的全局index
  const int num_global_thread_group = gridDim.x * blockDim.y;   // warp的總數(shù)量
  const int lane_id = threadIdx.x;    // warp內(nèi)的線(xiàn)程id
  const int64_t step = num_global_thread_group * rows_per_access;   // 處理的行數(shù)步長(zhǎng)
  // for 循環(huán)的開(kāi)始為 row = 全局的線(xiàn)程組id * 每個(gè)線(xiàn)程組一次處理的行數(shù),結(jié)束為總行數(shù)
  for (int64_t row = global_thread_group_id * rows_per_access; row < rows; row += step) {
    // 寄存器中開(kāi)辟一塊內(nèi)存記錄當(dāng)前線(xiàn)程組處理的每一行的最大值
    ComputeType thread_max[rows_per_access];
    // 對(duì)每一行的循環(huán)
    #pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      // 把當(dāng)前行最小值初始化為 -inf
      thread_max[row_id] = -Inf();
      // 獲取第 row_id 行的指針
      ComputeType* row_buf = buf[row_id];
      #pragma unroll
      for (int pack_id = 0; pack_id < num_packs; ++pack_id) {
        const int pack_offset = pack_id * pack_size;
        // 相鄰的線(xiàn)程讀取相鄰的pack,也就是說(shuō)同一個(gè)線(xiàn)程處理的相鄰pack間間隔是thread_group_width*pack_size
        const int col = (pack_id * thread_group_width + lane_id) * pack_size;
        if (!padding || col < cols) {
          // 使用 obj.template 調(diào)用函數(shù)模板防止歧義,load 一個(gè) pack 的數(shù)據(jù)到寄存器
          load.template load(row_buf + pack_offset, row + row_id, col);
          #pragma unroll
          for (int i = 0; i < pack_size; ++i) {
            thread_max[row_id] = max(thread_max[row_id], row_buf[pack_offset + i]);
          }
        } else {  // 需要 padding 且 col > cols,這種情況對(duì)于第 col 列的數(shù)據(jù)直接將 row_buf 賦最新小值,不影響 thread_max 計(jì)算即可
          #pragma unroll
          for (int i = 0; i < pack_size; ++i) { row_buf[pack_offset + i] = -Inf(); }
        }
      }
    }
    // 記錄屬于同一個(gè)warp的線(xiàn)程組的每一行的最大值,也就是需要進(jìn)行一次warpReduce max
    ComputeType warp_max[rows_per_access];
    #pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      // 通過(guò)線(xiàn)程束洗牌函數(shù)對(duì)一個(gè)線(xiàn)程組內(nèi)的所有線(xiàn)程的 thread_max 求規(guī)約得到一個(gè)線(xiàn)程組處理的每一行的最大值
      warp_max[row_id] = WarpAllReduce(thread_max[row_id]);
    }
    // 記錄當(dāng)前線(xiàn)程組處理的每一行的sum
    ComputeType thread_sum[rows_per_access];
    #pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      thread_sum[row_id] = 0;
      ComputeType* row_buf = buf[row_id];
      #pragma unroll
      for (int i = 0; i < cols_per_thread; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          row_buf[i] = Exp(row_buf[i] - warp_max[row_id]);
          thread_sum[row_id] += row_buf[i];
        } else if (algorithm == Algorithm::kLogSoftmax) {
          row_buf[i] -= warp_max[row_id];
          thread_sum[row_id] += Exp(row_buf[i]);
        } else {
          __trap();   // 內(nèi)核的執(zhí)行被中止并在主機(jī)程序中引發(fā)中斷。
        }
      }
    }
    ComputeType warp_sum[rows_per_access];
    #pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      warp_sum[row_id] = WarpAllReduce(thread_sum[row_id]);
    }
    #pragma unroll
    for (int row_id = 0; row_id < rows_per_access; ++row_id) {
      ComputeType* row_buf = buf[row_id];
      #pragma unroll
      for (int i = 0; i < cols_per_thread; ++i) {
        if (algorithm == Algorithm::kSoftmax) {
          row_buf[i] = Div(row_buf[i], warp_sum[row_id]);
        } else if (algorithm == Algorithm::kLogSoftmax) {
          row_buf[i] -= Log(warp_sum[row_id]);
        } else {
          __trap();
        }
      }
      #pragma unroll
      for (int i = 0; i < num_packs; ++i) {
        const int col = (i * thread_group_width + lane_id) * pack_size;
        if (!padding || col < cols) {
          store.template store(row_buf + i * pack_size, row + row_id, col);
        }
      }
    }
  }
}
,>,>

具體代碼如上,在解讀之前,需要先介紹一下幾個(gè)重要參數(shù)的意義。

algorithm:代表所使用的的算法,有 Algorithm::kSoftmax 和 Algorithm::kLogSoftmax。

global_thread_group_id:當(dāng)前線(xiàn)程組的全局索引

lane_id:當(dāng)前線(xiàn)程在線(xiàn)程組內(nèi)的索引

首先在核函數(shù)內(nèi)部做了幾個(gè)編譯期斷言操作,確保核函數(shù)能夠正常啟動(dòng)。然后在寄存器中定義了一個(gè)二維數(shù)組 buf[rows_per_access][cols_per_thread] 用來(lái)存儲(chǔ)矩陣中的數(shù)據(jù),我們知道,寄存器中的變量只能對(duì)當(dāng)前線(xiàn)程可見(jiàn),每個(gè)線(xiàn)程中都有一個(gè)變量 buf,但是存儲(chǔ)的值可以不同,這里是為了減少對(duì)全局內(nèi)存的讀取,所以給每個(gè)線(xiàn)程都定義一個(gè)寄存器變量用于存儲(chǔ)該線(xiàn)程處理的矩陣元素。
接著是一個(gè) Grip-loop 的循環(huán),因?yàn)橛锌赡芫仃囆袛?shù)過(guò)大導(dǎo)致前面求 num_blocks 的時(shí)候是根據(jù)硬件參數(shù)選取的,這時(shí)候每個(gè)線(xiàn)程不止處理一次,所以循環(huán)步長(zhǎng)設(shè)置為網(wǎng)格大小。Grip-loop 內(nèi)部定義了一個(gè)寄存器變量 thread_max[rows_per_access],這個(gè)數(shù)組用來(lái)存儲(chǔ)當(dāng)前線(xiàn)程處理的元素中的每一行的最大值。接下來(lái)就是一個(gè) reduceMax 操作。
(1)reduceMax
如圖 2,每個(gè)線(xiàn)程處理了多個(gè) Pack 的數(shù)據(jù),求最大值需要兩層循環(huán)。第一層循環(huán)中把一個(gè) Pack 的矩陣元素 load 到 buf 數(shù)組中,這里主要是要理解 col 變量的含義,結(jié)合圖 2 的示意圖不難理解,相鄰的線(xiàn)程讀取相鄰的 Pack 的目的是讓一個(gè)線(xiàn)程束中各線(xiàn)程單次訪(fǎng)問(wèn)的數(shù)據(jù)在內(nèi)存中相鄰,這是一個(gè)合并訪(fǎng)問(wèn)的概念,目的是提升訪(fǎng)問(wèn)效率。第二層循環(huán)中對(duì)單個(gè) Pack 中的元素求最大值存到 thread_max 中。
注意,這時(shí)候 thread_max 中存的只是每個(gè)線(xiàn)程內(nèi)部處理的元素的最大值,但是 reduceMax 操作要獲取的是矩陣每一行的最大值,由于 WarpSoftmax 的應(yīng)用范圍就是一個(gè)線(xiàn)程組處理一行數(shù)據(jù),所以再對(duì)線(xiàn)程組內(nèi)所有的 thread_max 求最大值即可。前面說(shuō)過(guò),每個(gè)線(xiàn)程內(nèi)部都有一個(gè) thread_max 變量,對(duì)這些變量求最大值,必然要在線(xiàn)程間進(jìn)行通信,源碼中使用了 WarpAllReduce() 函數(shù)完成了這一操作得到了矩陣每一行的最大值 warp_max,核心就是利用了線(xiàn)程束洗牌指令 __shfl_xor_sync 完成了一個(gè)束內(nèi)折半規(guī)約操作,筆者之前在另一篇文章也有介紹:【CUDA編程】CUDA編程中的并行規(guī)約問(wèn)題。有興趣的讀者可以去 cuda 官網(wǎng)詳細(xì)了解一下束內(nèi)洗牌指令的用法,當(dāng)然了這里也可以直接使用共享內(nèi)存存儲(chǔ)數(shù)據(jù),我們知道共享內(nèi)存在整個(gè) block 都是可見(jiàn)的,也就不需要使用束內(nèi)通信,但是從訪(fǎng)問(wèn)性能出發(fā),共享內(nèi)存是不如寄存器快的,所以 oneflow 選擇了寄存器。,>

template class ReductionOp, typename T, int thread_group_width = kWarpSize>
__inline__ __device__ T WarpAllReduce(T val) {
  for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
    val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}

(2)reduceSum
接下來(lái)就是 reduceSum 操作,這里源碼提供了兩種算法: Algorithm::kSoftmax 和 Algorithm::kLogSoftmax。kSoftmax 就是公式(2)中的計(jì)算公式,kLogSoftmax 計(jì)算的是 計(jì)算公式如下:

reduceSum 的計(jì)算思路和 reduceMax 相同,先在寄存器定義一個(gè)變量 thread_sum 然后求出各個(gè)線(xiàn)程內(nèi)的指數(shù)和,最后束內(nèi)規(guī)約求每一行的指數(shù)和 warp_sum。
broadcastSub、exp、broadcastDiv 這三個(gè)操作比較簡(jiǎn)單,其邏輯就直接包含在兩個(gè)規(guī)約操作的實(shí)現(xiàn)代碼里,這里不再贅述,至此 WarpSoftmax 源碼解讀完畢,有興趣的讀者可以自行嘗試。調(diào)用時(shí)可以將矩陣 cols 限制在 1024 以?xún)?nèi)調(diào)用 DispatchSoftmax 函數(shù),也可以直接調(diào)用 DispatchSoftmaxWarpImpl 函數(shù)。

4 小結(jié)

總結(jié)一下 WarpSoftmax 源碼中的一些值得注意的內(nèi)容。

數(shù)據(jù) Pack 可以有效地提升訪(fǎng)問(wèn)帶寬,pack_size 可以根據(jù) cuda 中最大支持一次 128 bit 的讀寫(xiě)來(lái)確定最大值。

WarpSoftmax 的核心就是束內(nèi)規(guī)約,利用了束內(nèi)線(xiàn)程可互相訪(fǎng)問(wèn)寄存器的特性提高效率,但受制于單個(gè)線(xiàn)程可使用的寄存器大小,所以 WarpSoftmax 不適用于矩陣列數(shù)比較大的場(chǎng)景。

源碼中對(duì)于 pack_size 和 row_per_access 的確定都比較簡(jiǎn)單粗暴,可以進(jìn)行更細(xì)致的處理。

審核編輯:湯梓紅

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀(guān)點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • 源碼
    +關(guān)注

    關(guān)注

    8

    文章

    632

    瀏覽量

    29110
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3112

    瀏覽量

    48658
  • 深度學(xué)習(xí)
    +關(guān)注

    關(guān)注

    73

    文章

    5463

    瀏覽量

    120890
  • OneFlow
    +關(guān)注

    關(guān)注

    0

    文章

    9

    瀏覽量

    8786

原文標(biāo)題:【CUDA編程】OneFlow Softmax 算子源碼解讀之WarpSoftmax

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    OneFlow Softmax算子源碼解讀BlockSoftmax

    寫(xiě)在前面:筆者這段時(shí)間工作太忙,身心俱疲,博客停更了一段時(shí)間,現(xiàn)在重新?lián)炱饋?lái)。本文主要解讀 OneFlow 框架的第二種 Softmax 源碼實(shí)現(xiàn)細(xì)節(jié),即 block 級(jí)別的
    的頭像 發(fā)表于 01-08 09:26 ?642次閱讀
    <b class='flag-5'>OneFlow</b> <b class='flag-5'>Softmax</b><b class='flag-5'>算子</b><b class='flag-5'>源碼</b><b class='flag-5'>解讀</b><b class='flag-5'>之</b>BlockSoftmax

    caffe源碼解讀《九》softmax

    編程語(yǔ)言行業(yè)芯事經(jīng)驗(yàn)分享
    蒙特卡洛家的樹(shù)
    發(fā)布于 :2022年03月09日 15:34:58

    TensorFlow、PyTorch,“后浪”OneFlow 有沒(méi)有機(jī)會(huì)

    TensorFlow、PyTorch,“后浪”OneFlow 有沒(méi)有機(jī)會(huì) | 一流科技工程師成誠(chéng)編者按:7月31日,一流科技在創(chuàng)業(yè)1300天后,他們宣布開(kāi)源自研的深度學(xué)習(xí)框架OneFlow,此前,CSDN對(duì)CEO袁進(jìn)輝進(jìn)行了專(zhuān)訪(fǎng)。本文中,一流科技工程師成...
    發(fā)表于 07-27 08:24

    機(jī)器學(xué)習(xí)的Softmax定義和優(yōu)點(diǎn)

    Softmax在機(jī)器學(xué)習(xí)中有非常廣泛的應(yīng)用,但是剛剛接觸機(jī)器學(xué)習(xí)的人可能對(duì)Softmax的特點(diǎn)以及好處并不理解,其實(shí)你了解了以后就會(huì)發(fā)現(xiàn),Softmax計(jì)算簡(jiǎn)單,效果顯著,非常好用。
    的頭像 發(fā)表于 03-15 17:18 ?4631次閱讀
    機(jī)器學(xué)習(xí)的<b class='flag-5'>Softmax</b>定義和優(yōu)點(diǎn)

    使用Softmax的信息來(lái)教學(xué) —— 知識(shí)蒸餾

    當(dāng)處理一個(gè)分類(lèi)問(wèn)題時(shí),使用softmax作為神經(jīng)網(wǎng)絡(luò)的最后一個(gè)激活單元是非常典型的用法。這是為什么呢?因?yàn)?b class='flag-5'>softmax函數(shù)接受一組logit為輸入并輸出離散類(lèi)別上的概率分布。
    的頭像 發(fā)表于 10-10 10:23 ?2094次閱讀

    基于EAIDK的人臉?biāo)惴☉?yīng)用-源碼解讀(2)

    上一期介紹了基于EAIDK的人臉?biāo)惴☉?yīng)用,本期從應(yīng)用角度,解讀一下該案例源碼。本期案例源碼解讀,主要從源碼目錄結(jié)構(gòu)、配置文件、模型目...
    的頭像 發(fā)表于 12-10 21:14 ?835次閱讀

    開(kāi)源軟件-OneFlow通用深度學(xué)習(xí)框架

    ./oschina_soft/oneflow.zip
    發(fā)表于 06-20 09:26 ?2次下載
    開(kāi)源軟件-<b class='flag-5'>OneFlow</b>通用深度學(xué)習(xí)框架

    Sobel算子原理介紹與實(shí)現(xiàn)方法

    索貝爾算子(Sobel operator)主要用作邊緣檢測(cè),在技術(shù)上,它是一離散性差分算子,用來(lái)運(yùn)算圖像亮度函數(shù)的灰度近似值。在圖像的任何一點(diǎn)使用此算子,將會(huì)產(chǎn)生對(duì)應(yīng)的灰度矢量或是其
    的頭像 發(fā)表于 07-21 17:27 ?1.3w次閱讀

    flowflops:OneFlow模型的Flops計(jì)算

    用于計(jì)算 OneFlow 模型的 FLOPs 和 Parameters 的第三方庫(kù)。
    的頭像 發(fā)表于 11-16 10:04 ?1145次閱讀

    解析OneFlow Element-Wise算子實(shí)現(xiàn)方法

    雖然這種寫(xiě)法非常簡(jiǎn)單明了,但卻存在明顯的性能問(wèn)題。所以這篇文章將基于OneFlow開(kāi)源的Element-Wise CUDA算子方案來(lái)解釋如何寫(xiě)一個(gè)高性能的Element-Wise CUDA算子。
    的頭像 發(fā)表于 12-12 10:54 ?1484次閱讀

    解析OneFlow BatchNorm相關(guān)算子實(shí)現(xiàn)

    可以看到 CUDNN_BATCHNORM_PER_ACTIVATION 被用于非卷積層,在OneFlow中只有當(dāng)輸入Tensor的維度為2時(shí)才選取這種模式。而
    的頭像 發(fā)表于 12-23 15:08 ?616次閱讀

    深度學(xué)習(xí)編譯器Layerout Transform優(yōu)化

    繼續(xù)深度學(xué)習(xí)編譯器的優(yōu)化工作解讀,本篇文章要介紹的是OneFlow系統(tǒng)中如何基于MLIR實(shí)現(xiàn)Layerout Transform。
    的頭像 發(fā)表于 05-18 17:32 ?683次閱讀

    PyTorch教程4.1Softmax回歸

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程4.1Softmax回歸.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:46 ?0次下載
    PyTorch教程4.1<b class='flag-5'>之</b><b class='flag-5'>Softmax</b>回歸

    PyTorch教程4.4從頭開(kāi)始實(shí)現(xiàn)Softmax回歸

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程4.4從頭開(kāi)始實(shí)現(xiàn)Softmax回歸.pdf》資料免費(fèi)下載
    發(fā)表于 06-05 15:37 ?0次下載
    PyTorch教程4.4<b class='flag-5'>之</b>從頭開(kāi)始實(shí)現(xiàn)<b class='flag-5'>Softmax</b>回歸

    使用LabVIEW人工智能視覺(jué)工具包快速實(shí)現(xiàn)傳統(tǒng)Opencv算子的調(diào)用源碼

    電子發(fā)燒友網(wǎng)站提供《使用LabVIEW人工智能視覺(jué)工具包快速實(shí)現(xiàn)傳統(tǒng)Opencv算子的調(diào)用源碼.rar》資料免費(fèi)下載
    發(fā)表于 09-28 17:38 ?13次下載