大模型中的KV Cache

1. KV Cache的定义与核心原理

KV Cache(Key-Value Cache)是一种在Transformer架构的大模型推理阶段使用的优化技术,通过缓存自注意力机制中的键(Key)和值(Value)矩阵,避免重复计算,从而显著提升推理效率。

原理:

  • 自注意力机制:在Transformer中,注意力计算基于公式:
    Attention ( Q , K , V ) = softmax ( Q K ⊤ d k ) V = ∑ i = 1 n w i v i (加权求和形式) \begin{split} \text{Attention}(Q, K, V) &= \text{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \\ &= \sum_{i=1}^n w_i v_i \quad \text{(加权求和形式)} \end{split} Attention(Q,K,V)=softmax(dk QK)V=i=1nwivi(加权求和形式)
    其中,Q(Query)、K(Key)、V(Value)由输入序列线性变换得到。

  • 缓存机制:在生成式任务(如文本生成)中,模型以自回归方式逐个生成token。首次推理时,计算所有输入token的K和V并缓存;后续生成时,仅需为新token计算Q,并从缓存中读取历史K和V进行注意力计算。

  • 复杂度优化:传统方法的计算复杂度为O(n²),而KV Cache将后续生成的复杂度降为O(n),避免重复计算历史token的K和V。

2. KV Cache的核心作用

  • 加速推理:通过复用缓存的K和V,减少矩阵计算量,提升生成速度。例如,某聊天机器人应用响应时间从0.5秒缩短至0.2秒。

  • 降低资源消耗:显存占用减少约30%-50%(例如移动端模型从1GB降至0.6GB),支持在资源受限设备上部署大模型。

  • 支持长文本生成:缓存机制使推理耗时不再随文本长度线性增长,可稳定处理长序列(如1024 token以上)。

  • 保持模型性能:仅优化计算流程,不影响输出质量。

3. 技术实现与优化策略

实现方式:
  • 数据结构

    1. KV Cache以张量形式存储,Key Cache和Value Cache的形状分别为(batch_size, num_heads, seq_len, k_dim)(batch_size, num_heads, seq_len, v_dim)
  • 两阶段推理:

    1. 初始化阶段:计算初始输入的所有K和V,存入缓存。
    2. 迭代阶段:仅计算新token的Q,结合缓存中的K和V生成输出,并更新缓存。
      • 代码示例(Hugging Face Transformers):设置model.generate(use_cache=True)即可启用KV Cache。

优化策略:

  • 稀疏化(Sparse):仅缓存部分重要K和V,减少显存占用。

  • 量化(Quantization):将K和V矩阵从FP32转为INT8/INT4,降低存储需求。

共享机制(MQA/GQA):

  • Multi-Query Attention (MQA):所有注意力头共享同一组K和V,显存占用降低至1/头数。

  • Grouped-Query Attention (GQA):将头分组,组内共享K和V,平衡性能和显存。

4. 挑战与局限性

  • 显存压力:随着序列长度增加,缓存占用显存线性增长(如1024 token占用约1GB显存),可能引发OOM(内存溢出)。

  • 冷启动问题:首次推理仍需完整计算K和V,无法完全避免初始延迟。

5、python实现

import torch
import torch.nn as nn# 超参数
d_model = 4
n_heads = 1
seq_len = 3
batch_size = 3# 初始化参数(兼容多头形式)
Wq = nn.Linear(d_model, d_model, bias=False)
Wk = nn.Linear(d_model, d_model, bias=False)
Wv = nn.Linear(d_model, d_model, bias=False)# 生成模拟输入(整个序列一次性输入)
input_sequence = torch.randn(batch_size, seq_len, d_model)  # [B, L, D]# 初始化 KV 缓存(兼容多头格式)
kv_cache = {"keys": torch.empty(batch_size, 0, n_heads, d_model // n_heads),  # [B, T, H, D/H]"values": torch.empty(batch_size, 0, n_heads, d_model // n_heads) 
}# 因果掩码预先生成(覆盖最大序列长度)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # [L, L]'''
本循环是将整句话中的token一个一个输入,并更新KV缓存;
所以无需显示的因果掩码,因为因果掩码只用于计算注意力权重时,而计算注意力权重时,KV缓存中的key和value已经包含了因果掩码的信息。'''for step in range(seq_len):# 1. 获取当前时间步的输入(整个批次)current_token = input_sequence[:, step, :]  # [B, 1, D]# 2. 计算当前时间步的 Q/K/V(保持三维结构)q = Wq(current_token)  # [B, 1, D]k = Wk(current_token)  # [B, 1, D]v = Wv(current_token)  # [B, 1, D]# 3. 调整维度以兼容多头格式(关键修改点)def reshape_for_multihead(x):return x.view(batch_size, 1, n_heads, d_model // n_heads).transpose(1, 2)  # [B, H, 1, D/H]# 4. 更新 KV 缓存(增加时间步维度)kv_cache["keys"] = torch.cat([kv_cache["keys"], reshape_for_multihead(k).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)kv_cache["values"] = torch.cat([kv_cache["values"],reshape_for_multihead(v).transpose(1, 2)  # [B, T+1, H, D/H]], dim=1)# 5. 多头注意力计算(支持批量处理)q_multi = reshape_for_multihead(q)  # [B, H, 1, D/H]k_multi = kv_cache["keys"].transpose(1, 2)  # [B, H, T+1, D/H]print("q_multi shape:", q_multi.shape)print("k_multi shape:", k_multi.shape)# 6. 计算注意力分数(带因果掩码)attn_scores = torch.matmul(q_multi, k_multi.transpose(-2, -1)) / (d_model ** 0.5)print("attn_scores shape:", attn_scores.shape)# attn_scores = attn_scores.masked_fill(causal_mask[:step+1, :step+1], float('-inf'))# print("attn_scores shape:", attn_scores.shape)# 7. 注意力权重计算attn_weights = torch.softmax(attn_scores, dim=-1)  # [B, H, 1, T+1]# 8. 加权求和output = torch.matmul(attn_weights, kv_cache["values"].transpose(1, 2))  # [B, H, 1, D/H]# 9. 合并多头输出output = output.contiguous().view(batch_size, 1, d_model)  # [B, 1, D]print(f"Step {step} 输出:", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/905318.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Maven 公司内部私服中央仓库搭建 局域网仓库 资源共享 依赖包构建共享

介绍 公司内部私服搭建通常是为了更好地管理公司内部的依赖包和构建过程,避免直接使用外部 Maven 中央仓库。通过搭建私服,团队能够控制依赖的版本、提高构建速度并增强安全性。公司开发的一些公共工具库更换的提供给内部使用。 私服是一种特殊的远程仓…

文档外发安全:企业数据防护的最后一道防线

在当今数字化时代,数据已成为企业最宝贵的资产之一。随着网络安全威胁日益增多,企业安装专业加密软件已从"可选"变为"必选"。本文将全面分析企业部署华途加密解决方案后获得的各项战略优势。 一、数据安全防护升级 核心数据全面保护…

【ArcGIS】根据shp范围生成系列等距点:范围外等距点+渔网点(Python全代码)

【ArcGIS】根据shp范围生成系列等距点 目标1:生成边界外一定范围、并且等间距分布的点📁 所需数据:操作步骤-ArcGIS代码处理-Python 目标2:生成等距渔网点📁 所需数据:代码处理-Python 参考 目标1&#xff…

Docker 环境安装(2025最新版)

Docker在主流的操作系统和云平台上都可以使用,包括Linux操作 系统(如Ubuntu、 Debian、Rocky、Redhat等)、MacOS操作系统和 Windows操作系统,以及AWS等云平 台。 Docker官网: https://docs.docker.com/ 配置宿主机网…

Java并发编程-线程池(二)

文章目录 线程池的实现原理execute(Runnable command)**1. 阶段一:尝试创建核心线程****2. 阶段二:尝试将任务加入队列****3. 阶段三:尝试创建非核心线程或拒绝任务****关键机制与设计思想** 线程池的实现原理 当向线程池提交一个任务之后&a…

清华大学开源软件镜像站地址

清华大学开源软件镜像站: https://mirrors.tuna.tsinghua.edu.cn/

脑机接口技术:开启人类与机器融合的新时代

摘要 脑机接口(BCI)技术作为一项前沿科技,正在逐步打破人类与机器之间的沟通障碍,为医疗、娱乐、教育等多个领域带来前所未有的变革。本文将详细介绍脑机接口技术的基本原理、发展现状、应用场景以及面临的挑战和未来发展趋势&…

2025前端面试遇到的问题(vue+uniapp+js+css)

Vue相关面试题 vue2和vue3的区别 一、核心架构差异 特性Vue2Vue3响应式系统基于Object.defineProperty基于Proxy(支持动态新增/删除属性)代码组织方式Options API(data/methods分块)Composition API(逻辑按功能聚合&am…

Matlab基于SSA-MVMD麻雀算法优化多元变分模态分解

Matlab基于SSA-MVMD麻雀算法优化多元变分模态分解 目录 Matlab基于SSA-MVMD麻雀算法优化多元变分模态分解效果一览基本介绍程序设计参考资料效果一览 基本介绍 Matlab基于SSA-MVMD麻雀算法优化多元变分模态分解 可直接运行 分解效果好 适合作为创新点(Matlab完整源码和数据),…

Gatsby知识框架

一、Gatsby 基础概念 1. 核心特性 基于React的静态站点生成器:使用React构建,输出静态HTML/CSS/JS GraphQL数据层:统一的数据查询接口 丰富的插件系统:超过2000个官方和社区插件 高性能优化:自动代码分割、预加载、…

【论信息系统项目的资源管理】

论信息系统项目的资源管理 前言一、规划好资源管理,为保证项目完成做好人员规划二、估算活动资源,为制订项目进度计划提供资源需求三、获取项目资源,组建一个完备的项目团队四、建设项目团队,提高工作能力,促进团队成员…

idea Maven 打包SpringBoot可执行的jar包

背景&#xff1a;当我们需要坐联调测试的时候&#xff0c;需要对接前端同事&#xff0c;则需要打包成jar包直接运行启动服务 需要将项目中的pom文件增加如下代码配置&#xff1a; <build><plugins><plugin><groupId>org.springframework.boot</gr…

VScode 的插件本地更改后怎么生效

首先 vscode 的插件安装地址为 C:\Users\%USERNAME%\.vscode\extensions 找到你的插件包进行更改 想要打印日志&#xff0c;用下面方法 vscode.window.showErrorMessage(console.log "${name}" exists.); 打印结果 找到插件&#xff0c;点击卸载 然后点击重新启动 …

Python训练营打卡——DAY24(2025.5.13)

目录 一、元组 1. 通俗解释 2. 元组的特点 3. 元组的创建 4. 元组的常见用法 二、可迭代对象 1. 定义 2. 示例 3. 通俗解释 三、OS 模块 1. 通俗解释 2. 目录树 四、作业 1. 准备工作 2. 实战代码示例​ 3. 重要概念解析 一、元组 是什么​​&#xff1a;一种…

初入OpenCV

OpenCV简介 OpenCV是一个开源的跨平台计算机视觉库&#xff0c;它实现了图像处理和计算机视觉方面的很多通用算法。 应用场景&#xff1a; 目标识别&#xff1a;人脸、车辆、车牌、动物&#xff1b; 自动驾驶&#xff1b;医学影像分析&#xff1b; 视频内容理解分析&#xff…

讯联云库项目开发日志(一)

1、设计数据库 2、写基本框架 entity、controller、service、exception、utils、mapper mapper层&#xff1a; 生成了一系列的CRUD方法 工具类&#xff1a;线程安全的日期工具类、 ​​参数校验工具类​ 线程安全的日期工具类​​&#xff1a;主要用于 ​​日期格式化&…

langchain学习

无门槛免费申请OpenAI ChatGPT API搭建自己的ChatGPT聊天工具 https://nuowa.net/872 基本概念 LangChain 能解决大模型的两个痛点&#xff0c;包括模型接口复杂、输入长度受限离不开自己精心设计的模块。根据LangChain 的最新文档&#xff0c;目前在 LangChain 中一共有六大…

Protobuf工具

#region 知识点一 什么是 Protobuf //Protobuf 全称是 protocol - buffers&#xff08;协议缓冲区&#xff09; // 是谷歌提供给开发者的一个开源的协议生成工具 // 它的主要工作原理和我们之前做的自定义协议工具类似 // 只不过它更加的完善&…

zst-2001 上午题-历年真题 软件工程(38个内容)

CMM 软件工程 - 第1题 b 软件工程 - 第2题 c 软件工程 - 第3题 c 软件工程 - 第4题 b 软件工程 - 第5题 b CMMI 软件工程 - 第6题 0.未完成&#xff1a;未执行未得到目标。1.已执行&#xff1a;输入-输出实现支持2.已管理&#xff1a;过程制度化&#x…

软考架构师考试-UML图总结

考点 选择题 2-4分 案例分析0~1题和面向对象结合考察&#xff0c;前几年固定一题。近3次考试没有出现。但还是有可能考。 UML图概述 1.用例图&#xff1a;描述系统功能需求和用户&#xff08;参与者&#xff09;与系统之间的交互关系&#xff0c;聚焦于“做什么”。 2.类图&…