0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

深入淺出理解PagedAttention CUDA實(shí)現(xiàn)

深度學(xué)習(xí)自然語言處理 ? 來源:PaperWeekly ? 2024-01-09 11:43 ? 次閱讀

vLLM 中,LLM 推理的 prefill 階段 attention 計(jì)算使用第三方庫 xformers 的優(yōu)化實(shí)現(xiàn),decoding 階段 attention 計(jì)算則使用項(xiàng)目編譯 CUDA 代碼實(shí)現(xiàn)。具體代碼在 vllm 的 csrc/attention/attention_kernels.cu 文件里,開發(fā)者洋洋灑灑寫了八百多行 CUDA 代碼。

Attention 計(jì)算時(shí)使用頁式(paged)管理 KVCache 用于增加服務(wù)吞吐率,但對延遲有負(fù)面影響,因此高效的 PA 實(shí)現(xiàn)方法,利用頁式內(nèi)存管理同時(shí)盡量降低其負(fù)面影響,對框架的綜合性能表現(xiàn)至關(guān)重要。

本文章將描述 PA CUDA Kernel 的實(shí)現(xiàn)細(xì)節(jié),這些細(xì)節(jié)是公開的論文和博客所不涉及的,但卻對框架的速度至關(guān)重要。另外,PA 實(shí)現(xiàn)改編自 FasterTransformers 某個(gè)版本的 MHA 實(shí)現(xiàn),NV 原始版本對 GPU 特性的運(yùn)用也是相當(dāng)老道的,值得大家借鑒。

vLLM 中有兩個(gè)版本 PA,使用一個(gè)簡單的啟發(fā)式方法來決定是使用 V1 還是 V2 版本。V1 是本文介紹的版本,改編自 FasterTransformers 的 MHA 實(shí)現(xiàn)。V2 是參考 FlashDecoding 方式進(jìn)行實(shí)現(xiàn),對 sequence 維度進(jìn)行切分以增加并行粒度,關(guān)于 FlashDecoding 可以參考本人知乎文章。V1 適合長度小于 8192 或者 num_seqs * num_heads>512 的情況。

參數(shù)定義和數(shù)據(jù)結(jié)構(gòu)

num_seq:本次推理請求 sequence 數(shù)目。

num_head:Query 的 head 數(shù)目。

num_kv_heads:Key、Value 的 head 數(shù)目,對于 MHA 和 num_head 相同,如果是 GQA、MQA 則 num_kv_heads 小于 num_head。

head_size hidden dimension,特征的維度。

PA 使用 tensor 的維度信息

out [num_seqs, num_heads, head_size]

Q [num_seqs, num_heads, head_size]

KCache [num_blocks, num_kv_heads, head_size/x, block_size, x]:x 表示一個(gè)向量化的大小,如 float16 -> 16 / sizeof(float16) = 8。

VCache [num_blocks, num_kv_heads, head_size, block_size]

Paged 內(nèi)存管理相關(guān)的輔助數(shù)據(jù)結(jié)構(gòu):

blk_size:也就是 block_size,是 KVCache page 的最高維,KVCache 是若干個(gè) page 的集合,每個(gè) page 存(blk_size, num_head,head_size)個(gè) K、V 的元素。

head_mapping [num_heads] 用于 MQA, GQA,確定用的 KV_head

block_tables [num_seqs, max_num_blocks_per_seq] block_tables 映射表,表示每個(gè) sequence 映射到哪幾個(gè) block 上

context_lens [num_seqs] 用于變長

課前問題

如果你能回答以下兩個(gè)問題,那么說明你已經(jīng)非常熟練地掌握了 PA 實(shí)現(xiàn),并可以用批判性的眼光審閱本文,找出其中可能存在的錯(cuò)誤。如果你暫時(shí)無法回答這些問題,請不要擔(dān)憂,閱讀完本文后會(huì)給你答案。

Q1:為什么 K Cache 的 layout 和 V Cache layout 不一樣?

Q2:PA 實(shí)現(xiàn)和 FlashAttention 有什么區(qū)別?

PagedAttention算子計(jì)算流程

首先,按照 CUDA 編程模型對任務(wù)進(jìn)行并行劃分,grid 大小(num_heads, num_seqs),grid 中每個(gè) CUDA thread block 大?。∟UM_THREADS),NUM_THREADS 是常量默認(rèn)為 128,也就說每個(gè) thread block 包含 128 個(gè)線程,負(fù)責(zé)完成 output 矩陣一行(包含 head_size 個(gè)元素)結(jié)果的 attention 計(jì)算任務(wù)。thread block 中的線程進(jìn)一步劃分若干個(gè)WARP。

眾所周知,WARP 是 GPU 一個(gè)基本的執(zhí)行單元,由 32 個(gè)線程組成,這些線程以 SMIT 方式在硬件上同時(shí)執(zhí)行相同的指令,在不同的數(shù)據(jù)上進(jìn)行操作。在 PA 中比較特殊的是,warp 內(nèi) 32 個(gè)線程進(jìn)一步劃分為 blk_size 個(gè) thread group,這和 paged KVCache 設(shè)計(jì) x 息息相關(guān)的,馬上會(huì)細(xì)講。

Attention 計(jì)算 softmax(QK^T)V,一圖勝前言,后面流程介紹將圍繞下面這幅圖展開。其中 thread block, warp, thread group, thread 別用不同顏色表示。

ed093146-ae34-11ee-8b88-92fbcf53809c.png

▲ 圖1:PagedAttention CUDA計(jì)算流程

在上圖的左側(cè)部分,我們看到了 Q 矩陣,這部分描述了從顯存讀取 Q 數(shù)據(jù)到共享內(nèi)存的過程。在這個(gè)過程中,一個(gè) CUDA 線程塊會(huì)讀取圖中 Q 矩陣的一行(包含 head_size個(gè)元素)并將其存入共享內(nèi)存。

這個(gè)過程是通過一個(gè)循環(huán)來實(shí)現(xiàn)的,在每次迭代中,每個(gè) thread group 會(huì)讀取 16 字節(jié)的 Q 數(shù)據(jù)(例如,如果使用 float16,那么就是 8 個(gè)元素)。每個(gè) warp 會(huì)讀取 16*blk_size 字節(jié)的 Q 數(shù)據(jù),這些數(shù)據(jù)對應(yīng)于一個(gè) sequence 的一個(gè) head,由 CUDA grid 索引指定。當(dāng)循環(huán)訪問結(jié)束后,共享內(nèi)存存儲(chǔ) Q 行的一部分。如下圖所示,綠色部分表示存儲(chǔ)在一個(gè)線程讀入共享內(nèi)存中的數(shù)據(jù)。

ed1a631c-ae34-11ee-8b88-92fbcf53809c.png

圖 1 中上面部分 K 矩陣部分描述了從顯存讀取 K Cache 到寄存器的過程。每個(gè)序列的 K Cache 包含 cxt_length * num_kv_heads * head_size 個(gè)元素,但由于采用了頁式內(nèi)存管理,這些元素在內(nèi)存中的存儲(chǔ)并不連續(xù)。每個(gè) thread block 只負(fù)責(zé)計(jì)算一個(gè) sequence 一個(gè) head 的 QK^T,因此只需要 ctx_length * head_size 個(gè) K Cache 元素。

然而,由于 ctx_length 維度的存儲(chǔ)是不連續(xù)的,并且以 blk_size 個(gè) token 為粒度分布在不同的內(nèi)存地址,我們需要根據(jù)query的head_idx和 seq_idx 訪問 block_table 以找到 K Cache的physical_block_num。為了方便后續(xù)的描述,我們可以將 K Cache 視為(:, head_size)的形狀,其中 head_size 個(gè)元素組成一行。

K Cache 的布局為 [num_blocks, num_kv_heads, head_size/x, block_size, x],這是為了優(yōu)化寫入 shared memory 的操作。在 Q 和 K 矩陣的同一行元素被讀入寄存器并進(jìn)行點(diǎn)乘運(yùn)算后,結(jié)果需要被存入 shared memory。

如果一個(gè) warp 中所有線程都計(jì)算 Q、K 同一行數(shù)據(jù),會(huì)導(dǎo)致寫入 shared memory 的同一個(gè)位置,這將造成 warp 內(nèi)不同線程順序地寫入。因此,為了優(yōu)化,warp的線程最好計(jì)算 Q 和 K 的不同行數(shù)據(jù)。因此,在設(shè)計(jì) K Cache 布局時(shí),我們將 block_size 放在比 head_size 更低的維度。

由于 warp size 大于 block_size,我們需要將 head_size 拆分為 head_size/x 和 x 兩個(gè)維度,借 x 到最低維度,以確保每個(gè)線程讀入的數(shù)據(jù)量和計(jì)算量都足夠大。最后,每個(gè)線程組派一個(gè)線程去寫入 shared memory,這樣一個(gè) warp 有 blk_size 個(gè)線程并行寫入 shared memory,從而增加了 shared memory 的訪問帶寬。這種設(shè)計(jì)策略是為了實(shí)現(xiàn)高效的并行計(jì)算和內(nèi)存訪問,以提高整體的計(jì)算性能。

在代碼實(shí)現(xiàn)中,訪問 K 矩陣需要一個(gè)循環(huán),該循環(huán)使得 CUDA 線程塊中的所有 warp 依次訪問 num_block 個(gè)頁面。在每次循環(huán)迭代中,每個(gè) warp 負(fù)責(zé)訪問連續(xù)的 blk_size個(gè)K Cache 行,這涉及到的數(shù)據(jù)量為 blk_size * head_size 個(gè)元素。同時(shí),每個(gè) thread group 負(fù)責(zé)訪問 K Cache 的一行,將 head_size 個(gè)元素加載到自己的寄存器中。

接著,寄存器中的 Q 和 K 數(shù)據(jù)元素立即進(jìn)行點(diǎn)乘運(yùn)算,運(yùn)算結(jié)果被寫入 shared memory 中。因此,線程塊的 shared memory 存儲(chǔ)了一行 QK^T 的結(jié)果,包含 ctx_length 個(gè)元素。這種實(shí)現(xiàn)方式充分利用了 CUDA 的并行計(jì)算能力,以提高數(shù)據(jù)處理的效率。

然后,thread block 對 shared memory 中元素進(jìn)行 max,sum 方式 reduction,然后計(jì)算得到 softmax 結(jié)果。

圖 1 右邊 V 矩陣部分描述從顯存讀 V Cache 到寄存器過程。和 K Cache 一樣,CUDA thread block 依次訪問 num_blk 個(gè)物理塊到寄存器,每個(gè) warp 負(fù)責(zé) blk_size 個(gè) token 的 page 內(nèi)存,page 的真實(shí)物理地址同樣需要進(jìn)行索引。

不過這里不需要以 thread group 為單位訪問 16 字節(jié),而是每個(gè) thread 訪問 16 字節(jié)的元素。訪問完就可以與 shared memory 的 softmax(QK^T) 中間結(jié)果對應(yīng)位置 16 字節(jié)的數(shù)據(jù)進(jìn)行點(diǎn)乘,得到一個(gè) float 結(jié)果,寫到 output 對應(yīng)位置中。

為什么V Cache的layout是 [num_blocks, num_kv_heads, head_size, block_size],和 K Cache layout 不一樣?這是因?yàn)?V 要去做點(diǎn)乘的對象在shared memory,只需要讀,不涉及并行寫的問題。

和 FlashAttention(FA)有什么不同?結(jié)合我的圖和中間 FAv2 的流程圖對比就一目了然了。FA 用了兩層循環(huán),每次寫一個(gè) Tile 的 output tensor,而 PA 一直只有一層循環(huán),每次寫一行 output tensor。因?yàn)槊看味加姓械?QK^T 中間結(jié)果,不需要 online softmax 這種花哨技巧。

ed257e1e-ae34-11ee-8b88-92fbcf53809c.png

PAv1的問題

以我粗淺的理解指出幾點(diǎn) vLLM PAv1 的問題。一、和 MHA 相比,MQA 和 GAQ 沒有減少對 KV Cache 的讀寫次數(shù)。讀 K、V Cache 時(shí)候只是做了一個(gè) head_idx 的轉(zhuǎn)換,會(huì)重復(fù)從顯存讀相同的 head。二、對于 seq length 很長情況沒法適應(yīng),因?yàn)闆]有沿著 ctx_length 或者 batch 維度做切分。這點(diǎn) FlashAttention 和 FlashDecoding 就做了,因此 PAv2 借鑒了 FA 的切分思想。

總結(jié)

vLLM 的 paged attention v1 實(shí)現(xiàn)繼承自 FasterTransformers MHA 實(shí)現(xiàn),它和 FlashAttention 的并行任務(wù)劃分方式不同。其中對 KVCache layout 的設(shè)計(jì)比較巧妙,充分利用了 shared memory 寫帶寬,是一種常用 CUDA 編程技巧。







審核編輯:劉清

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報(bào)投訴
  • 寄存器
    +關(guān)注

    關(guān)注

    31

    文章

    5250

    瀏覽量

    119199
  • Cache
    +關(guān)注

    關(guān)注

    0

    文章

    128

    瀏覽量

    28188
  • 內(nèi)存管理
    +關(guān)注

    關(guān)注

    0

    文章

    167

    瀏覽量

    14099
  • MQA
    MQA
    +關(guān)注

    關(guān)注

    0

    文章

    3

    瀏覽量

    6034

原文標(biāo)題:vLLM皇冠上的明珠:深入淺出理解PagedAttention CUDA實(shí)現(xiàn)

文章出處:【微信號(hào):zenRRan,微信公眾號(hào):深度學(xué)習(xí)自然語言處理】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    深入淺出AVR

    深入淺出AVR,一本書。
    發(fā)表于 07-15 12:02

    深入淺出玩轉(zhuǎn)FPGA

    深入淺出玩轉(zhuǎn)FPGA
    發(fā)表于 07-21 09:21

    深入淺出ARM7

    深入淺出ARM7
    發(fā)表于 08-18 10:12

    HDMI技術(shù)深入淺出

    HDMI技術(shù)深入淺出
    發(fā)表于 08-19 10:52

    深入淺出Android

    深入淺出Android
    發(fā)表于 08-20 10:14

    深入淺出Android

    深入淺出Android
    發(fā)表于 04-26 10:48

    深入淺出安防視頻監(jiān)控系統(tǒng)

    深入淺出安防視頻監(jiān)控系統(tǒng)深入淺出安防視頻監(jiān)控系統(tǒng)
    發(fā)表于 05-22 19:28

    深入淺出AVR

    深入淺出AVR
    發(fā)表于 08-23 10:10

    深入淺出數(shù)據(jù)分析

    深入淺出數(shù)據(jù)分析,有需要的朋友下來看看。
    發(fā)表于 01-15 14:22 ?0次下載

    深入淺出談多層面板布線技巧

    深入淺出談多層面板布線技巧
    發(fā)表于 12-13 22:20 ?0次下載

    深入淺出Android—Android開發(fā)經(jīng)典教材

    深入淺出Android—Android開發(fā)經(jīng)典教材
    發(fā)表于 10-24 08:52 ?15次下載
    <b class='flag-5'>深入淺出</b>Android—Android開發(fā)經(jīng)典教材

    深入淺出數(shù)字信號(hào)處理

    深入淺出數(shù)字信號(hào)處理
    發(fā)表于 12-07 20:14 ?487次閱讀

    深入淺出理解阻抗匹配

    深入淺出理解阻抗匹配
    的頭像 發(fā)表于 02-03 15:14 ?3960次閱讀

    深入淺出學(xué)習(xí)250個(gè)通信原理資源下載

    深入淺出學(xué)習(xí)250個(gè)通信原理資源下載
    發(fā)表于 04-12 09:16 ?28次下載

    深入淺出學(xué)習(xí)低功耗藍(lán)牙協(xié)議棧

    深入淺出學(xué)習(xí)低功耗藍(lán)牙協(xié)議棧
    發(fā)表于 06-23 10:35 ?56次下載