Contents

LLM 推理加速技巧 KVCache 解析

💡 最近李宏毅老师更新了关于 KVCache 的教学视频,正好趁这个机会整理了下 KVCache 原理相关内容。

核心原理:transformer 架构是自回归的,在计算 Attention 是会用到过去的 key 和 value,缓存这些值通过空间换时间的方式,提高算力计算效率。

原理分析

简单介绍下 Transformer 中的 Attention 层的计算过程:

准备阶段:

  1. 输入的句子通过 tokenizer 进行分词
  2. 对每个 token 进行 embedding 得到 embedding 向量

这里得到 embedding 向量,作为 Attention 层的输入向量 x。注意这个 x 向量,可以是任意类型层的输出,不一定是 Embedding 层的输出。

计算阶段

  1. x 向量乘上预先训练好的权重矩阵 Wq、Wk、Wv 得到 q、k、v 向量
  2. 使用 attention 公式计算得到 z 向量

以下图为例:

假设最开始我们只有 Token 1:

  1. 通过预先训练的好的权重矩阵可以得到 Query Token 1、Key Token 1 和 Value Token 1
  2. 使用 Attention 公式,计算得到 Token 1 对应的 Attention
  3. 通过 softmax 层并全连接到词表得到 Token 2

得到的 Token 2 再走相同的方式进行计算,从图中下半部分可以发现,计算过程中,会用到 Key Token 1 和 Value Token 1。

而所谓的 KVCache 就是将计算过的 Key、Value 在显存里缓存起来(紫色部分)。如果不进行缓存,那么每次历史的 K,V 向量都需要重新计算。

代码实现

调整 Attention 层的 forword 方法,增加 cache 参数,输入参数 x 也会根据不同的阶段做调整,其中 x 向量的形状(Batch, Time/Seq_len, Channels):

  • 在 Prefill 阶段,由于所有的 token 都是已知的,无需 cache,参数 x 对应的是完整的输入向量,直接进行 attention 计算
  • 在 Decode 阶段,参数 x 为前一次新生成 token,cache 为 kv 缓存值。

KVCache 大小估算

在实际自回归模型中,特别是最近发布的一些推理模型中,往往串联了很多的 Attention 层。提前估算 KVCache 的大小,有助于在时间(算力)和空间(显存)之间寻找平衡,提升经济效益。

从上面的代码中也不难发现,kvcache 是以 [(k, v), (k, v)] 的形式存储,如果有 N 个 Attention 层,那么就会有 n 个这样的元组,其中 k 和 v 的形状均为 [b, n_head, s, head_dims]。那么对应 cache 的数据量就是:

2 * b * n_head * s * head_dims * N

以 Qwen2.5-72B 为例,先通过 得到 hidden_size、num_hidden_layers、num_attention_heads、num_key_value_heads、max_position_embeddings。

  • num_attention_heads 是注意力头数量,对应的是 n_head,数值 64
  • hidden_size 是隐藏层的特征维度大小,需要除以注意力头的数量,得到 head_dims,数值:8192/64
  • num_hidden_layers 是隐藏层数量,对应的是 N,数值 80
  • max_position_embeddings 是最大长度,对应的是 s,数值 32768

如果 kv 值以 bf16 精度存储,并且只处理一个请求的话,一次最大长度的推理,kvcache 的大小计算如下:

(2 * b * n_head * s * head_dims * N) * 2 \
= 2 * 1 * 64  * 32768 * 8192/64 * 80 * 2 \
= 80G

🤯 但是这里有个坑,config.json 中有一个 num_key_value_heads,说明这个模型使用了 GQA 组查询注意力,因此,需要使用 num_key_value_heads 作为 n_head 即 8 个。所以整体的 kvcache 大小应该是 10GB。

💡 **num_attention_heads** 对应 Query (Q) 的头数,**num_key_value_heads**** ** 对应 Key (K)Value (V) 的头数

Paged Attention 算法:将 attention 算法产生的连续的 key value 向量按照非连续的 block 进行组织和管理,以减少显存碎片。同时,对已经计算过的 Block 进行 哈希(Hash)缓存,当新请求的前缀 hash 值,与缓存命中,则可以跳过 prefill 阶段。

GQA(Grouped-Query Attention)

将 Q 划分为若干个,每个组内的多个 Q 共享一组 K 和 V。

参考资料