0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

理解KV cache的作用及優(yōu)化方法

jf_pmFSk4VX ? 來源:知乎 ? 2023-12-04 15:24 ? 次閱讀

作者丨紫氣東來

在 Transformer 的 Encoder-base 的模型(如 BERT系列)中,推理和訓練過程保持了高度的統(tǒng)一性(差異僅僅在于是否存在反向過程)。而在 Decoder-base 的生成式模型(如 GPT系列)中,推理和訓練存在相當大的差異性,主要體現(xiàn)在推理過程具有以下3點特征:

自回歸

兩階段(第一階段輸入 prompt,第二階段輸入上一個生成的token)

KV cache

以上三點實際上也是相輔相成、不可分割的,其中自回歸的生成模式是根本原因,兩階段是外在的體現(xiàn)形式,KV cache是優(yōu)化手段。

下面將通過梳理整個推理過程,來理解 KV cache 的作用及優(yōu)化方法。

一、KV cache 的由來與基本矛盾

885a422a-9125-11ee-939d-92fbcf53809c.png

第一階段(prompt 輸入):

88749c6a-9125-11ee-939d-92fbcf53809c.png

8884c4a0-9125-11ee-939d-92fbcf53809c.png

888bc2aa-9125-11ee-939d-92fbcf53809c.png

889cd7f2-9125-11ee-939d-92fbcf53809c.jpg

KV cache 作用過程

第二階段(token by token):

88b78a48-9125-11ee-939d-92fbcf53809c.png

88bbef52-9125-11ee-939d-92fbcf53809c.png

88c97186-9125-11ee-939d-92fbcf53809c.png

KV cache的顯存占用分析

88d47d88-9125-11ee-939d-92fbcf53809c.png

88e1e108-9125-11ee-939d-92fbcf53809c.png

batch size s+n KV cache(GB) KV cache/weight
4 4096 81 0.23
16 4096 324 0.93
64 4096 1297 3.71

可見隨著 batch size 和 長度的增大,KV cache 占用的顯存開銷快速增大,甚至會超過模型本身。

而 LLM 的窗口長度也在不斷增大,因此就出現(xiàn)一組主要矛盾,即:對不斷增長的 LLM 的窗口長度的需要與有限的 GPU 顯存之間的矛盾。因此優(yōu)化 KV cache 就顯得非常必要。

二、KV cache 優(yōu)化的典型方法

2.1 共用 KV cache:MQA,GQA

MQA (Multi Query Attention,多查詢注意力) 是多頭注意力的一種變體。其主要區(qū)別在于,在 MQA 中不同的注意力頭共享一個K和V的集合,每個頭只單獨保留了一份查詢參數(shù)。因此K和V的矩陣僅有一份,這大幅度減少了顯存占用,使其更高效。由于MQA改變了注意力機制的結(jié)構(gòu),因此模型通常需要從訓練開始就支持 MQA 。也可以通過對已經(jīng)訓練好的模型進行微調(diào)來添加多查詢注意力支持,僅需要約 5% 的原始訓練數(shù)據(jù)量 就可以達到不錯的效果。包括 Falcon、SantaCoder、StarCoder 等在內(nèi)很多模型都采用了 MQA 機制。

# Multi Head Attention
self.Wqkv = nn.Linear(     # Multi-Head Attention 的創(chuàng)建方法
    self.d_model,
    3 * self.d_model,     # Q、K和V 3 個矩陣, 所以是 3 * d_model
    device=device
)
query, key, value = qkv.chunk(3, dim=2)      # 每個 tensor 都是 (1, 512, 768)

# Multi Query Attention
self.Wqkv = nn.Linear(       # Multi-Query Attention 的創(chuàng)建方法
    d_model,
    d_model + 2 * self.head_dim,    # 只創(chuàng)建Q的頭向量,所以是 1* d_model, 而K和V不再具備單獨的頭向量, 所以是 2 * self.head_dim
    device=device,
)
query, key, value = qkv.split(
    [self.d_model, self.head_dim, self.head_dim],    # query -> (1, 512, 768), key   -> (1, 512, 96), value -> (1, 512, 96)
    dim=2
)

88ec3ba8-9125-11ee-939d-92fbcf53809c.jpg

MHA v.s. GQA v.s. MQA

GQA(Grouped Query Attention,分組查詢注意力)是一種介于多頭注意力和 MQA 之間的折中方案。它將查詢頭(Query Heads)分組,并在每組中共享一個鍵頭(Key Head)和一個值頭(Value Head)。表達能力與推理速度:GQA既保留了多頭注意力的一定表達能力,又通過減少內(nèi)存訪問壓力來加速推理速度。

88f5d97e-9125-11ee-939d-92fbcf53809c.jpg

MHA, GQA, MQA 性能比較

2.2 窗口優(yōu)化

890f5b60-9125-11ee-939d-92fbcf53809c.png

891f68b6-9125-11ee-939d-92fbcf53809c.jpg

3)箭型 attention 窗口,在LM-Infinit中就已經(jīng)被提出了,其基本原理和StreamingLLM是一致的。

89312f42-9125-11ee-939d-92fbcf53809c.jpg

2.3 量化與稀疏

該類方法是基于壓縮的思想,通過量化與稀疏壓縮 KV cache 的 顯存消耗。

當前主流推理框架都在逐步支持 KV cache 量化,一個典型的案例是lmdeploy,下圖展示了其在TurboMind框架下 KV INT8 的支持情況。

893c6b6e-9125-11ee-939d-92fbcf53809c.jpg

lmdeploy 的推理特性

稀疏的方法也比較簡單,其做法無外乎以下幾種方式:

894638b0-9125-11ee-939d-92fbcf53809c.jpg

這里最值得一提的是H2O。簡單來說就是通過動態(tài)的評價方式來判斷需要保留和廢棄的KV值,其評估的算法如下所示:

895912a0-9125-11ee-939d-92fbcf53809c.jpg

結(jié)果顯示,在 KV cache 稀疏到只有原來的 20% 時仍然可以保持很高的精度。

89688564-9125-11ee-939d-92fbcf53809c.jpg

2.4 存儲與計算優(yōu)化

該方法的典型代表即vLLM的 PagedAttention,簡單來說就是允許在非連續(xù)的內(nèi)存空間中存儲連續(xù)的 K 和 V。詳情可參考筆者之前的文章,在此不予贅述

FlashDecoding 是在 FlashAttention 的基礎(chǔ)上針對 inference 的優(yōu)化主要分為三步:

長文本下將KV分成更小且方便并行的chunk

對每個chunk的KV,Q和他們進行之前一樣的FlashAttention獲取這個chunk的結(jié)果

對每個chunk的結(jié)果進行reduce

8977e086-9125-11ee-939d-92fbcf53809c.gif

三、StreamingLLM:簡潔高效的“無限長度”

StreamingLLM 的基本思想同樣是來源于上述的窗口思想,其最大的創(chuàng)新在于提出了識別并保存模型固有的「注意力池」(attention sinks)錨定其推理的初始 token。下面將詳細討論其工作的原理。

3.1 精度是如何保證的?

核心的發(fā)現(xiàn):Lost in the Middle。

多個研究都發(fā)現(xiàn),self-attention 的注意力比較集中于頭部和尾部,對文本中段的注意力相對較弱,如下圖所示:

89ac0e4c-9125-11ee-939d-92fbcf53809c.jpg

繪制出 self-attention 的熱力圖也能看到這一點,由此當文本長度超過額定長度時,頭部的 token 就會被遺棄掉,這就會在 softmax 階段產(chǎn)生很大的問題。

89b64c0e-9125-11ee-939d-92fbcf53809c.jpg

89ce455c-9125-11ee-939d-92fbcf53809c.png

89d52ad4-9125-11ee-939d-92fbcf53809c.png

3.2 “無限長度”是如何做到的?

該問實際上可以換種表述為:如何在文本長度不斷增加的情況下,保證GPU顯存不會溢出。由于該方案主要應用于多輪對話的場景,那么有必要回顧一下當前多輪對話生成的主流做法,概括起來就以下幾點:

將用戶輸入與模型輸出拼接,中間做必要分割;

多個輪次之間倒序排列,并拼接;

如果前邊所有輪次長度之和超過最大長度,則截斷到最大長度;

上述過程可以用代碼描述如下:

  history = ["
[|Human|]{}
[|AI|]{}".format(x[0], x[1]) for x in history]
  history.append("
[|Human|]{}
[|AI|]".format(text))
  history_text = ""
  flag = False
  for x in history[::-1]:
    if tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1) <= max_length:
 ? ? ? ? ? ?history_text = x + history_text
 ? ? ? ? ? ?flag = True
 ? ? ? ?else:
 ? ? ? ? ? ?break
 ? ?if flag:
 ? ? ? ?inputs = tokenizer(prompt + history_text, return_tensors="pt")
 ? ? ? ?input_ids = inputs["input_ids"][:, -max_length:].to(device)
 ? ? ? ?torch.cuda.empty_cache()
 ? ? ? ?return input_ids, text
 ? ?else:
 ? ? ? ?return None

實際上這就是典型的滑動窗口的做法,滑窗?的存在保證了 GPU 的顯存不會溢出,但是由于上節(jié)的討論,會存在精度損失。

89f51d1c-9125-11ee-939d-92fbcf53809c.jpg

8a000696-9125-11ee-939d-92fbcf53809c.png

審核編輯:黃飛

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學習之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • gpu
    gpu
    +關(guān)注

    關(guān)注

    27

    文章

    4591

    瀏覽量

    128144
  • GPT
    GPT
    +關(guān)注

    關(guān)注

    0

    文章

    347

    瀏覽量

    15182
  • LLM
    LLM
    +關(guān)注

    關(guān)注

    0

    文章

    247

    瀏覽量

    279

原文標題:漫談 KV Cache 優(yōu)化方法,深度理解 StreamingLLM

文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    內(nèi)存分配及Cache優(yōu)化

    C6000的芯片支持庫CSL中的CACHE-setL2Mode函數(shù),將L2設置為198KB的SRAM和64KB的Cache模式。并根據(jù)H.264算法本身的結(jié)構(gòu),采取以下方法對存儲器進行優(yōu)化
    發(fā)表于 08-10 14:54

    如何理解C6678中關(guān)于cache的描述?

    在TMS320C6678中,有這樣對cache的描述:“L1D memory cannot be cached within L1D cache, L1P cache, or L2 cache
    發(fā)表于 06-21 16:07

    請教關(guān)于EDMA和cache優(yōu)化的疑惑

    hi,everyone:經(jīng)??吹骄W(wǎng)上說,EDMA算法優(yōu)化,在片上L2SRAM 中開辟內(nèi)存,將片外數(shù)據(jù)從DDR或SDRAM 利用EDMA搬運到L2SRAM中。但是, 我有兩點疑惑:1.我覺得這種方法
    發(fā)表于 07-27 09:38

    使用CACHE_disableCaching函數(shù)禁止cache沒起作用

    CACHE_getMemRegionInfo (129, &pcx, &pfx); 讀取pcx的值 仍然是1,所以沒起作用。懷疑是當前模式是user mode,而修改MAR寄存器需要
    發(fā)表于 12-28 11:12

    Cache為什么還要分I-Cache,D-Cache,L2 Cache,作用是什么?

    Cache為什么還要分I-Cache,D-Cache,L2 Cache,作用是什么?
    發(fā)表于 10-25 06:38

    基于修正LRU的壓縮Cache替換策略

    優(yōu)化壓縮cache的替換策略為目標,提出一種優(yōu)化的基于修正LRU的壓縮cache替換策略MLRU-C。MLRU-C策略能利用壓縮cache
    發(fā)表于 04-15 09:51 ?36次下載

    Cache中Tag電路的設計

    摘要:在SoC系統(tǒng)中,片上緩存(Cache)的采用是解決片上處理器和片外存儲器之間速度差異的重要方法Cache中用來存儲標記位并判斷Cache是否命中的Tag電路的設計將會影響到整個
    發(fā)表于 05-08 09:26 ?11次下載

    降低Cache失效率的方法[1]

    降低Cache失效率的方法[1]  學習目標:     理解失效的三種類型(3C);
    發(fā)表于 04-13 16:32 ?4186次閱讀

    降低Cache失效率的方法[2]

    降低Cache失效率的方法[2] 表4.7列出了在這兩種極端情況之間的各種塊大小和各種 Cache 容量的平均訪存時間。速度最快的情況: Cache 容量為1KB、4KB、1
    發(fā)表于 04-13 16:33 ?4829次閱讀

    一種有效的Cache優(yōu)化替換策略

    該問題,一種有效的解決方法優(yōu)化Cache替換策略,減少Cache中臟塊被替換出的數(shù)量?,F(xiàn)有研究主要通過在插入和訪問命中時給臟塊設定較高的保護優(yōu)先級來達到給臟塊額外保護的目的,但是在降
    發(fā)表于 11-27 15:16 ?1次下載
    一種有效的<b class='flag-5'>Cache</b><b class='flag-5'>優(yōu)化</b>替換策略

    Page Cache是什么 一文帶你深入理解Linux的Page Cache

    是什么? 為了理解 Page Cache,我們不妨先看一下 Linux 的文件 I/O 系統(tǒng),如下圖所示: Figure1. Linux 文件 I/O 系統(tǒng) 上圖中,紅色部分為 Page Cache??梢?Page
    的頭像 發(fā)表于 10-20 14:12 ?5754次閱讀
    Page <b class='flag-5'>Cache</b>是什么 一文帶你深入<b class='flag-5'>理解</b>Linux的Page <b class='flag-5'>Cache</b>

    什么是 Cache? Cache讀寫原理

    由于寫入數(shù)據(jù)和讀取指令分別通過 D-Cache 和 I-Cache,所以需要同步 D-Cache 和 I-Cache,即復制后需要先將 D-Cach
    發(fā)表于 12-06 09:55 ?1928次閱讀

    Cache與性能優(yōu)化精彩問答38條

    占用非常大的面積,大概在一半以上,而且一個好的 Cache 的設計復雜度非常高,可能比較 CPU 的 Pipeline 還要復雜。這里要考慮成本,設計復雜度,或者其他方面的考慮。你知道 L1
    的頭像 發(fā)表于 01-11 09:34 ?1138次閱讀

    深入理解Cache工作原理

    按照數(shù)據(jù)關(guān)系劃分:Inclusive/exclusive Cache: 下級Cache包含上級的數(shù)據(jù)叫inclusive Cache。不包含叫exclusive Cache。舉個例子,
    的頭像 發(fā)表于 05-30 16:02 ?672次閱讀
    深入<b class='flag-5'>理解</b><b class='flag-5'>Cache</b>工作原理

    Cache分類與替換算法

    根據(jù)不同的分類標準可以按以下3種方法Cache進行分類。 ?1)數(shù)據(jù)cache和指令cache ?● 指令cache:指令預取時使用的
    的頭像 發(fā)表于 10-31 11:26 ?765次閱讀
    <b class='flag-5'>Cache</b>分類與替換算法