实用指南:LLMs-from-scratch :KV 缓存

news/2025/11/10 18:00:29/文章来源:https://www.cnblogs.com/ljbguanli/p/19207840

实用指南:LLMs-from-scratch :KV 缓存

原文链接:https://github.com/rasbt/LLMs-from-scratch/tree/main/ch04/03_kv-cache

概述

简而言之,KV 缓存存储中间的键(K)和值(V)计算结果以便在推理过程中重复使用,这能在生成响应时带来显著的速度提升。缺点是会增加代码复杂性,增加内存使用量,并且不能在训练过程中使用。然而,在部署大语言模型时,推理速度的提升通常值得在代码复杂性和内存方面做出权衡。

工作原理

想象一下大语言模型正在生成一些文本。具体来说,假设大语言模型收到以下提示:“Time flies”。

下图显示了底层注意力分数计算的摘录,使用了第3章的修改图形,其中突出显示了键和值向量:

现在,正如我们在第2章和第4章中学到的,大语言模型一次生成一个词(或标记)。假设大语言模型生成了单词"fast",那么下一轮的提示就变成了"Time flies fast"。这在下图中进行了说明:

正如我们通过比较前面两个图所看到的,前两个标记的键和值向量完全相同,在每一轮的下一个标记文本生成中重新计算它们是浪费的。

因此,KV 缓存的想法是实现一个缓存机制,存储先前生成的键和值向量以供重复使用,这有助于我们避免不必要的重新计算。

KV 缓存实现

有许多方法可以实现 KV 缓存,主要思想是我们只为每个生成步骤中新生成的标记计算键和值张量。

我选择了一个强调代码可读性的简单方法。我认为最简单的方法就是浏览代码更改来了解它是如何实现的。

本文件夹中有两个文件:

  1. gpt_ch04.py:从第3章和第4章中提取的自包含代码,用于实现大语言模型并运行简单的文本生成函数
  2. gpt_with_kv_cache.py:与上面相同,但进行了必要的更改以实现 KV 缓存。

你可以选择:

a. 打开 gpt_with_kv_cache.py 文件并查找标记新更改的 # NEW 部分:

b. 通过你选择的文件差异工具查看两个代码文件以比较更改:

为了总结实现细节,这里是一个简短的演练。

1. 注册缓存缓冲区

MultiHeadAttention 构造函数内部,我们添加两个非持久缓冲区 cache_kcache_v,它们将保存跨步骤连接的键和值:

self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)

2. 带有 use_cache 标志的前向传播

接下来,我们扩展 MultiHeadAttention 类的 forward 方法以接受 use_cache 参数。在将新的标记块投影到 keys_newvalues_newqueries 之后,我们要么初始化 kv 缓存,要么追加到我们的缓存中:

def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape
keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x)
queries = self.W_query(x)
#...
if use_cache:
if self.cache_k is None:
self.cache_k, self.cache_v = keys_new, values_new
else:
self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
keys, values = self.cache_k, self.cache_v
else:
keys, values = keys_new, values_new
# ...
num_tokens_Q = queries.shape[-2]
num_tokens_K = keys.shape[-2]
if use_cache:
mask_bool = self.mask.bool()[
self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
]
self.ptr_current_pos += num_tokens_Q
else:
mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]

3. 清除缓存

在生成文本时,在独立序列之间(例如两次文本生成调用之间),我们必须重置两个缓冲区,因此我们还向 MultiHeadAttention 类添加了一个缓存重置方法:

def reset_cache(self):
self.cache_k, self.cache_v = None, None
self.ptr_current_pos = 0

4. 在完整模型中传播 use_cache

在对 MultiHeadAttention 类进行更改后,我们现在修改 GPTModel 类。首先,我们在构造函数中为标记索引添加位置跟踪:

self.current_pos = 0

然后,我们用显式循环替换单行块调用,通过每个变换器块传递 use_cache

def forward(self, in_idx, use_cache=False):
# ...
if use_cache:
pos_ids = torch.arange(
self.current_pos, self.current_pos + seq_len,
device=in_idx.device, dtype=torch.long
)
self.current_pos += seq_len
else:
pos_ids = torch.arange(
0, seq_len, device=in_idx.device, dtype=torch.long
)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
x = tok_embeds + pos_embeds
# ...
for blk in self.trf_blocks:
x = blk(x, use_cache=use_cache)

上述更改还需要对 TransformerBlock 类进行小的修改以接受 use_cache 参数:

def forward(self, x, use_cache=False):
# ...
self.att(x, use_cache=use_cache)

最后,我们向 GPTModel 添加模型级重置,以便一次清除所有块缓存:

def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.current_pos = 0

5. 在生成中使用缓存

通过对 GPTModelTransformerBlockMultiHeadAttention 的更改,最后,这是我们如何在简单的文本生成函数中使用 KV 缓存:

def generate_text_simple_cached(model, idx, max_new_tokens,
context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
with torch.no_grad():
if use_cache:
# 用完整提示初始化缓存
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
# a) 选择具有最高对数概率的标记(贪婪采样)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) 将其追加到运行序列中
idx = torch.cat([idx, next_idx], dim=1)
# c) 只向模型提供新标记
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx

注意,我们在 c) 中只通过 logits = model(next_idx, use_cache=True) 向模型提供新标记。没有缓存时,我们向模型提供整个输入 logits = model(idx[:, -ctx_len:], use_cache=False),因为它没有存储的键和值可以重复使用。

简单性能比较

在概念层面介绍了 KV 缓存之后,最大的问题是它在小例子的实际应用中表现如何。为了试用这个实现,我们可以将前面提到的两个代码文件作为 Python 脚本运行,这将运行小型 1.24 亿参数大语言模型来生成 200 个新标记(给定 4 个标记的提示"Hello, I am"开始):

pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
python gpt_ch04.py
python gpt_with_kv_cache.py

在配备 M4 芯片的 Mac Mini(CPU)上,结果如下:

标记/秒
gpt_ch04.py27
gpt_with_kv_cache.py144

因此,正如我们所看到的,我们已经在小型 1.24 亿参数模型和短 200 标记序列长度下获得了约 5 倍的速度提升。(注意,这个实现针对代码可读性进行了优化,而不是针对 CUDA 或 MPS 运行时速度进行了优化,后者需要预分配张量而不是重新实例化和连接它们。)

注意: 在两种情况下,模型都生成"胡言乱语",即看起来像这样的文本:

输出文本:Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous bore ITVEGIN ministriesysics Kle functional recountrictionchangingVirgin embarrassedgl …

这是因为我们还没有训练模型。下一章训练模型,你可以在训练好的模型上使用 KV 缓存(但是,KV 缓存只能在推理期间使用)来生成连贯的文本。在这里,我们使用未训练的模型来保持代码简单。

不过,更重要的是,gpt_ch04.pygpt_with_kv_cache.py 实现产生完全相同的文本。这告诉我们 KV 缓存实现是正确的——很容易犯索引错误,这可能导致不同的结果。

KV 缓存的优缺点

随着序列长度的增加,KV 缓存的好处和缺点在以下方面变得更加明显:

  • [好处] 计算效率提高:没有缓存时,步骤 t 的注意力必须将新查询与 t 个先前的键进行比较,因此累积工作量呈二次方增长,O(n²)。有了缓存,每个键和值只计算一次然后重复使用,将总的每步复杂度降低到线性,O(n)。

  • [缺点] 内存使用量线性增长:每个新标记都会追加到 KV 缓存中。对于长序列和更大的大语言模型,累积的 KV 缓存会变得更大,这可能消耗大量甚至令人望而却步的(GPU)内存。作为解决方法,我们可以截断 KV 缓存,但这会增加更多复杂性(但同样,在部署大语言模型时这可能是值得的。)

优化 KV 缓存实现

虽然我上面的 KV 缓存概念实现有助于清晰理解,主要面向代码可读性和教育目的,但在实际场景中部署它(特别是对于更大的模型和更长的序列长度)需要更仔细的优化。

扩展缓存时的常见陷阱

  • 内存碎片和重复分配:如前所示,通过 torch.cat 持续连接张量会导致性能瓶颈,因为频繁的内存分配和重新分配。

  • 内存使用量的线性增长:没有适当的处理,KV 缓存大小对于非常长的序列变得不切实际。

技巧 1:预分配内存

与其重复连接张量,我们可以基于预期的最大序列长度预分配足够大的张量。这确保了一致的内存使用并减少开销。在伪代码中,这可能看起来如下:

# 键和值的预分配示例
max_seq_len = 1024  # 预期的最大序列长度
cache_k = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)
cache_v = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)

在推理期间,我们可以简单地写入这些预分配张量的切片。

技巧 2:通过滑动窗口截断缓存

为了避免耗尽我们的 GPU 内存,我们可以实现带有动态截断的滑动窗口方法。通过滑动窗口,我们只在缓存中维护最后 window_size 个标记:

# 滑动窗口缓存实现
window_size = 512
cache_k = cache_k[:, :, -window_size:, :]
cache_v = cache_v[:, :, -window_size:, :]

实际中的优化

你可以在 gpt_with_kv_cache_optimized.py 文件中找到这些优化。

在配备 M4 芯片的 Mac Mini(CPU)上,使用 200 标记生成和等于上下文长度的窗口大小(以保证相同结果),代码运行时间比较如下:

标记/秒
gpt_ch04.py27
gpt_with_kv_cache.py144
gpt_with_kv_cache_optimized.py166

不幸的是,在 CUDA 设备上速度优势消失了,因为这是一个微小的模型,设备传输和通信超过了 KV 缓存对这个小模型的好处。

额外资源

  1. Qwen3 从零开始的 KV 缓存基准测试
  2. Llama 3 从零开始的 KV 缓存基准测试
  3. 从零开始理解和编码大语言模型中的 KV 缓存 – 这个 README 的更详细写作

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/961624.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前置和后置的区别

1、前置式返回的引用,效率会高一点 2、后置式返回的常对象,内部需要一个临时对象,效率相对低一些 备注: 1、前置式和后置式都没有参数,为了加以区分,再后置式增加int自变量,默认值为0 2、对于后置式返回常对象,…

2025年11月太阳能板/光伏板/电池板/单晶硅/多晶硅板前十厂家排名:深圳精益太阳能板领跑行业

文章摘要 本文基于2025年太阳能板行业发展趋势,分析了全球太阳能板市场的竞争格局,重点介绍了前十强品牌的排名、技术优势及服务特点。行业发展迅猛,高效、稳定、环保成为核心需求,本文提供详细排名和品牌信息,并…

TCP报文中的时间戳有什么作用

以上仅供参考,如有疑问,留言联系

响应式编程 - reactor 初识

Reactor 3 是一个围绕该规范构建的库,将响应式编程Reactive Streams范式引入JVM。 在本课程中,你将熟悉 Reactor API。那么,让我们快速介绍一下响应式流和响应式编程中更通用的概念。 package com.qinrenjihe;impor…

ubuntu16.04安装CUDA驱动 - 小

背景:项目需要使用PyTorch ,调用这两个命令nvidia-smi nvcc --version安装cuda,先安装显卡驱动 检查显卡型号:lspci | grep -i nvidia 01:00.0 VGA compatible controller: NVIDIA Corporation GP106 [GeForce …

深入解析:统一高效图像生成与编辑!百度新加坡国立提出Query-Kontext,多项任务“反杀”专用模型

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

2025年11月太阳能板生产厂家排名前十榜单:深圳精益太阳能板引领行业

摘要 随着全球对可再生能源需求的增长,太阳能板行业在2025年持续创新与扩张。本文基于权威市场数据和行业报告,精选出前十名太阳能板生产厂家,重点推荐深圳精益太阳能板作为榜首。榜单结合技术参数、用户口碑和品牌…

reactor 初识

package com.qinrenjihe;import org.jspecify.annotations.NonNull; import reactor.core.publisher.Flux;public class Main {// 创建一个空的 Fluxstatic Flux<@NonNull String> emptyFlux() {return Flux.emp…

QOJ6608 Descent of Dragons

为什么这题放在了 NOIP T2?自闭了……修改只会使值从 \(x\) 变成 \(x+1\),这个对整体的值域变化是非常小的。 对于一个阈值 \(lim\),考虑 \(01\) 序列 \(A_{lim}\),\(A_{lim,i}=[a_i\ge lim]\)。 对于一次修改,实…

2026年HR 数字化转型趋势:AI如何帮助HR从招聘到绩效全流程人效提升 48%?

根据艾瑞咨询 2025 年《中国 HR SaaS 行业研究报告》显示,预计 2025 年国内 HR SaaS 市场规模将突破 240 亿元,其中 AI 技术贡献的价值占比超 60%。这一数据背后,是 AI 正在彻底重塑 HR 全价值链 —— 从招聘的简历…

Windows利用批处理脚本判断端口, 启动tomcat

以下是一个完整的 Windows 批处理脚本,用于检查指定端口是否被占用,并根据结果选择是否启动 Tomcat。如果端口被占用,还可以选择结束占用端口的进程,再启动 Tomcat。批处理脚本代码batch@echo off :: 设置需要检查…

2025最新实测对比:5款热门工程项目管理系统 协同能力与实用体验深度测评

最近花了两个月时间,我们把市面上主流的5款工程项目管理系统都实际用了一遍。 说实话,这个测评做得挺烧脑的,光是测试数据就整理了十几个G。今天就把最真实的体验分享给大家,希望能帮正在选型的工程公司少走点弯路…

2025年双轴拌馅机实力厂家权威推荐榜单:调味料拌馅机/酱菜搅拌机/翻斗式拌馅机源头厂家精选

在食品工业自动化升级与标准化生产需求持续增长的背景下,双轴拌馅机作为肉制品、酱菜、调味品等食品加工的核心设备,其搅拌均匀性与生产效率直接影响产品品质与生产成本。根据食品机械行业数据显示,全球食品搅拌设备…

2025年终绩效,AI面谈系统让沟通效率翻倍,主管再也不用熬夜写总结

“又要准备绩效面谈了,光整理员工半年的绩效数据、目标完成情况就花了 2 天,面谈时还得边聊边记,生怕漏了关键信息,晚上还得熬夜补总结……” 这是很多企业主管在绩效周期内的真实写照。传统绩效面谈往往陷入 “形…

vue实现T型二维表格

图片实现T形2维表,上下滚动,T形左右可以各自水平滚动底部和顶部水平滚动保持一致实现excle复制粘贴T形左右宽度各自撑开代码如下<template><div class="fixed-table-container"ref="tableCo…

antd table 列表树形结构展示

// 原始数据(子节点字段为 subNodes) const rawData = [{key: 1,name: 父节点,subNodes: [{ key: 1-1, name: 子节点 },],}, ];// 转换函数:递归将 subNodes 改为 children const transformData = (data: any) =>…

2025年深圳救护车运转公司权威推荐榜单:正规救护车出租/急救车出租/出租救护车源头公司精选

在医疗服务需求多元化与人口老龄化趋势加速的背景下,深圳救护车运转服务市场正经历着从基础运输向专业化、分级化的转型升级。行业数据显示,社会对非急救转运服务的需求持续上升,尤其是在康复出院、跨省转院、异地就…

对隐式类型转换保持警觉

操作符重载引起的隐式类型转换 缺点:可能导致非预期的函数被调用 解决:以功能对等的另一个函数取代类型转换操作符 举例: class Rational{ public: Rational(int num = 0,int deno = 1):num_(num),deno_(deno){}; o…

es中批量删除数据

创建bulk_delete.json 文件 {"delete":{"_index":"vivian-scene-warn-history","_type":"warnHistory","_id":"5befb3a1b25c4841bca3637efc36a320&…

docker安装mysql/Redis/nacos/minio/es/xxl-job

yum安装jdk yum -y list java*yum install -y java-1.8.0-openjdk.x86_64#检查是否安装成功java -versiondocker安装mysql docker pull mysql:5.8 docker images mkdir -p /home/service/mysql/data mkdir -p /hom…