LLM 推理加速技巧 KVCache 解析
💡 最近李宏毅老师更新了关于 KVCache 的教学视频,正好趁这个机会整理了下 KVCache 原理相关内容。
核心原理:transformer 架构是自回归的,在计算 Attention 是会用到过去的 key 和 value,缓存这些值通过空间换时间的方式,提高算力计算效率。
原理分析
简单介绍下 Transformer 中的 Attention 层的计算过程:
准备阶段:
- 输入的句子通过 tokenizer 进行分词
- 对每个 token 进行 embedding 得到 embedding 向量
这里得到 embedding 向量,作为 Attention 层的输入向量 x。注意这个 x 向量,可以是任意类型层的输出,不一定是 Embedding 层的输出。
计算阶段
- x 向量乘上预先训练好的权重矩阵 Wq、Wk、Wv 得到 q、k、v 向量
- 使用 attention 公式计算得到 z 向量
以下图为例:
假设最开始我们只有 Token 1:
- 通过预先训练的好的权重矩阵可以得到 Query Token 1、Key Token 1 和 Value Token 1
- 使用 Attention 公式,计算得到 Token 1 对应的 Attention
- 通过 softmax 层并全连接到词表得到 Token 2
得到的 Token 2 再走相同的方式进行计算,从图中下半部分可以发现,计算过程中,会用到 Key Token 1 和 Value Token 1。
而所谓的 KVCache 就是将计算过的 Key、Value 在显存里缓存起来(紫色部分)。如果不进行缓存,那么每次历史的 K,V 向量都需要重新计算。
代码实现
调整 Attention 层的 forword 方法,增加 cache 参数,输入参数 x 也会根据不同的阶段做调整,其中 x 向量的形状(Batch, Time/Seq_len, Channels):
- 在 Prefill 阶段,由于所有的 token 都是已知的,无需 cache,参数 x 对应的是完整的输入向量,直接进行 attention 计算
- 在 Decode 阶段,参数 x 为前一次新生成 token,cache 为 kv 缓存值。
KVCache 大小估算
在实际自回归模型中,特别是最近发布的一些推理模型中,往往串联了很多的 Attention 层。提前估算 KVCache 的大小,有助于在时间(算力)和空间(显存)之间寻找平衡,提升经济效益。
从上面的代码中也不难发现,kvcache 是以 [(k, v), (k, v)] 的形式存储,如果有 N 个 Attention 层,那么就会有 n 个这样的元组,其中 k 和 v 的形状均为 [b, n_head, s, head_dims]。那么对应 cache 的数据量就是:
2 * b * n_head * s * head_dims * N
以 Qwen2.5-72B 为例,先通过 得到 hidden_size、num_hidden_layers、num_attention_heads、num_key_value_heads、max_position_embeddings。
- num_attention_heads 是注意力头数量,对应的是 n_head,数值 64
- hidden_size 是隐藏层的特征维度大小,需要除以注意力头的数量,得到 head_dims,数值:8192/64
- num_hidden_layers 是隐藏层数量,对应的是 N,数值 80
- max_position_embeddings 是最大长度,对应的是 s,数值 32768
如果 kv 值以 bf16 精度存储,并且只处理一个请求的话,一次最大长度的推理,kvcache 的大小计算如下:
(2 * b * n_head * s * head_dims * N) * 2 \
= 2 * 1 * 64 * 32768 * 8192/64 * 80 * 2 \
= 80G
🤯 但是这里有个坑,config.json 中有一个 num_key_value_heads,说明这个模型使用了 GQA 组查询注意力,因此,需要使用 num_key_value_heads 作为 n_head 即 8 个。所以整体的 kvcache 大小应该是 10GB。
💡
**num_attention_heads**对应 Query (Q) 的头数,**num_key_value_heads**** ** 对应 Key (K) 和 Value (V) 的头数
Paged Attention
Paged Attention 算法:将 attention 算法产生的连续的 key value 向量按照非连续的 block 进行组织和管理,以减少显存碎片。同时,对已经计算过的 Block 进行 哈希(Hash)缓存,当新请求的前缀 hash 值,与缓存命中,则可以跳过 prefill 阶段。
GQA(Grouped-Query Attention)
将 Q 划分为若干个组,每个组内的多个 Q 共享一组 K 和 V。