LLM笔记(九)KV缓存(2)

文章目录

    • 1. 背景与动机
    • 2. 不使用 KV Cache 的情形
      • 2.1 矩阵形式展开
      • 2.2 计算复杂度
    • 3. 使用 KV Cache 的优化
      • 3.1 核心思想
      • 3.2 矩阵形式展开
      • 3.3 计算复杂度对比
    • 4. 总结
    • 5. GPT-2 中 KV 缓存的实现分析
      • 5.1 缓存的数据结构与类型
      • 5.2 在注意力机制 (`GPT2Attention`) 中使用缓存
      • 5.3 缓存的更新机制 (`Cache.update`)
      • 5.4 在模型整体 (`GPT2Model`) 的 `forward` 方法中处理
      • 5.5 因果掩码 (Causal Mask) 与 KV 缓存的配合
      • 5.6 支持多种高效的注意力实现
      • 5.7 KV 缓存的完整工作流程 (自回归生成)
        • 5.7.1 初始步骤 (t=0):
        • 5.7.2 后续步骤 (t > 0):
      • KV 缓存的显著优势

  1. 看图学kv 很形象清楚
  2. gpt2源码
  3. 分析transformer模型的参数量、计算量、中间激活、KV cache量化分析了缓存
  4. kv解读

1. 背景与动机

在自回归生成(autoregressive generation)任务中,Transformer 解码器需要在每一步中根据前面已生成的所有 token 重新计算注意力(Attention),这会产生大量重复计算。引入 KV Cache(Key–Value Cache)后,能够将已计算的键值对缓存下来,仅对新增的 Query 进行点乘与加权,从而大幅降低时间与算力开销。

2. 不使用 KV Cache 的情形

2.1 矩阵形式展开

  • 第 1 步(生成第一个 token)

    Q 1 , K 1 , V 1 ∈ R 1 × d Q_1, K_1, V_1 \in \mathbb{R}^{1\times d} Q1,K1,V1R1×d

    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ⊤ d ) , V 1 Attention_1 = \mathrm{softmax}\Bigl(\frac{Q_1 K_1^\top}{\sqrt d}\Bigr),V_1 Attention1=softmax(d Q1K1),V1

  • 第 2 步(生成第二个 token)
    构造全序列的矩阵:
    image.png

    需重算完整注意力矩阵:

    A t t e n t i o n 1 : 2 = s o f t m a x ( Q 1 : 2 K 1 : 2 ⊤ d ) , V 1 : 2 Attention_{1:2} = \mathrm{softmax}\Bigl(\frac{Q_{1:2}K_{1:2}^\top}{\sqrt d}\Bigr),V_{1:2} Attention1:2=softmax(d Q1:2K1:2),V1:2

    计算出一个 2 × 2 2\times 2 2×2 矩阵,但我们只取最后一行作为输出。

  • 第 n 步

    Q 1 : n , K 1 : n , V 1 : n ∈ R n × d , A t t e n t i o n 1 : n = s o f t m a x ( Q 1 : n K 1 : n ⊤ d ) , V 1 : n Q_{1:n},K_{1:n},V_{1:n}\in\mathbb{R}^{n\times d},\quad Attention_{1:n} = \mathrm{softmax}\Bigl(\tfrac{Q_{1:n}K_{1:n}^\top}{\sqrt d}\Bigr),V_{1:n} Q1:n,K1:n,V1:nRn×d,Attention1:n=softmax(d Q1:nK1:n),V1:n

    每步均重新构建并计算 n × n n\times n n×n 注意力矩阵。

2.2 计算复杂度

  • 注意力矩阵构建 O ( n 2 ⋅ d ) O(n^2\cdot d) O(n2d)

  • 整体推理阶段:若生成总长度为 N N N,则总复杂度近似为

    ∑ n = 1 N O ( n 2 d ) ; = ; O ( N 3 d ) \sum_{n=1}^N O(n^2 d);=;O(N^3 d) n=1NO(n2d);=;O(N3d),

    由于每步都做重复计算,效率极低。

3. 使用 KV Cache 的优化

3.1 核心思想

  • 缓存已计算的 K, V:对于前序列位置的键值对,只需计算一次并存储。

  • 仅对新增 Query 进行点乘:第 n n n 步仅需计算 Q n Q_n Qn 与所有缓存 K 的点乘,得到长度为 n n n 的注意力权重,再加权叠加对应的 V。

3.2 矩阵形式展开

  • 第 1 步:如前,无缓存,计算
    A t t e n t i o n 1 = s o f t m a x ( Q 1 K 1 ⊤ / d ) , V 1 Attention_1 = \mathrm{softmax}(Q_1K_1^\top/\sqrt d),V_1 Attention1=softmax(Q1K1/d ),V1.

  • 第 2 步

    • 新增 Q 2 ∈ R 1 × d Q_2\in\mathbb{R}^{1\times d} Q2R1×d

    • 缓存矩阵已扩展为

      image.png

    • 只做一次 1 × 2 1\times 2 1×2 点乘:

      A t t e n t i o n 2 = s o f t m a x ( Q 2 K c a c h e ⊤ d ) , V c a c h e Attention_2 = \mathrm{softmax}\Bigl(\tfrac{Q_2 K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} Attention2=softmax(d Q2Kcache),Vcache,

      输出即为所需的 1 × d 1\times d 1×d 向量。

  • 第 n 步

    K c a c h e ∈ R n × d , V c a c h e ∈ R n × d , A t t e n t i o n n = s o f t m a x ( Q n K c a c h e ⊤ d ) , V c a c h e K_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad V_{\mathrm{cache}}\in\mathbb{R}^{n\times d},\quad Attention_n = \mathrm{softmax}\Bigl(\tfrac{Q_n K_{\mathrm{cache}}^\top}{\sqrt d}\Bigr),V_{\mathrm{cache}} KcacheRn×d,VcacheRn×d,Attentionn=softmax(d QnKcache),Vcache.

3.3 计算复杂度对比

模式每步复杂度总体复杂度(生成长度 N N N
无 Cache O ( n 2 d ) O(n^2 d) O(n2d) O ( N 3 d ) O(N^3 d) O(N3d)
有 KV Cache O ( n d ) O(n d) O(nd) ∑ n = 1 N O ( n d ) = O ( N 2 d ) \displaystyle\sum_{n=1}^N O(n d)=O(N^2 d) n=1NO(nd)=O(N2d)
  • 加速比:从二次方级别 O ( n 2 ) O(n^2) O(n2) 降到线性级别 O ( n ) O(n) O(n),对长序列提升显著。

4. 总结

  1. 多头注意力(Multi-Head)
    每个 head 独立缓存自己的 K, V 矩阵,计算时分别点乘再拼接。总体计算与存储线性可扩展。

  2. 缓存管理

    • 内存占用:缓存矩阵大小随生成长度增长,应考虑清理过旧不再需要的序列(如 sliding window)。

    • Batch 推理:对多条序列并行生成时,可为每条序列维护独立缓存,或统一按最大长度对齐。

  3. 硬件优化

    • 内存带宽:KV Cache 减少重复内存载入,对带宽友好;

    • 并行度:线性点乘更易与矩阵乘加(GEMM)指令级并行融合。

  4. 实践中常见问题

    • Cache 不命中:若使用 prefix-tuning 等技术动态修改 key/value,需谨慎处理缓存一致性。
    • 数值稳定性:长序列高维 softmax 易出现梯度消失/爆炸,可结合温度系数或分段归一化。

5. GPT-2 中 KV 缓存的实现分析

GPT-2(以及许多其他基于 Transformer 的自回归模型)在生成文本时,为了提高效率,会使用一种称为 KV 缓存 (Key-Value Cache) 的机制。其核心思想是:在生成第 t 个 token 时,计算注意力所需的键 (Key) 和值 (Value) 向量可以部分来自于已经生成的 t-1 个 token。通过缓存这些历史的 K 和 V 向量,可以避免在每一步生成时都对整个已生成序列重新进行昂贵的 K 和 V 计算。

5.1 缓存的数据结构与类型

Hugging Face Transformers 库为 GPT-2 提供了灵活的缓存管理机制,主要通过 Cache 基类及其子类实现。

  • Cache (基类): 定义了缓存对象的基本接口,例如 update (更新缓存) 和 get_seq_length (获取当前缓存的序列长度) 等方法。
  • DynamicCache:
    • 这是自回归生成时最常用的缓存类型。
    • 它允许缓存的序列长度动态增长。当生成新的 token 时,新计算出的 K 和 V 向量会被追加到已有的缓存后面。
    • 不需要预先分配固定大小的内存,更加灵活,但可能在内存管理上有一些开销。
  • StaticCache:
    • 在创建时就需要预先分配固定大小的内存空间来存储 K 和 V 向量。
    • 适用于已知最大生成长度或需要更可控内存占用的场景。
    • 如果生成的序列长度超过了预分配的大小,可能会出错或需要特殊处理。
  • EncoderDecoderCache:
    • 主要用于 Encoder-Decoder 架构的模型 (如 T5, BART)。
    • 它内部会分别管理编码器-解码器注意力(交叉注意力)的 KV 缓存和解码器自注意力的 KV 缓存。
    • GPT-2 是一个仅解码器 (Decoder-only) 模型,所以主要关注自注意力的缓存。
# 相关类的导入,展示了缓存工具的多样性
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

5.2 在注意力机制 (GPT2Attention) 中使用缓存

GPT2Attention 类的 forward 方法是 KV 缓存机制的核心应用点。

class GPT2Attention(nn.Module):  ...  def forward(  self,  hidden_states: Optional[Tuple[torch.FloatTensor]],  layer_past: Optional[Tuple[torch.Tensor]] = None, # 旧版本的缓存参数名  past_key_value: Optional[Cache] = None,           # 新版本的缓存对象  attention_mask: Optional[torch.FloatTensor] = None,  head_mask: Optional[torch.FloatTensor] = None,  use_cache: Optional[bool] = False,  output_attentions: Optional[bool] = False,  cache_position: Optional[torch.LongTensor] = None, # 指示新token在缓存中的位置  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  # 1. 计算当前输入 hidden_states 的 Q, K, V        # self.c_attn 是一个线性层,通常一次性计算出 Q, K, V 然后分割  query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)  # 2. 将 Q, K, V 重塑为多头形式 (batch_size, num_heads, seq_len, head_dim)        query = self._split_heads(query, self.num_heads, self.head_dim)  key = self._split_heads(key, self.num_heads, self.head_dim)  value = self._split_heads(value, self.num_heads, self.head_dim)  # 3. KV 缓存处理  if past_key_value is not None:  # 如果是 EncoderDecoderCache,根据是否交叉注意力选择正确的缓存  if isinstance(past_key_value, EncoderDecoderCache):  # ... (GPT-2 不直接使用此逻辑,但展示了其通用性)  pass  # 使用 cache_position 来更新缓存中的特定位置  cache_kwargs = {"cache_position": cache_position}  # 调用缓存对象的 update 方法  # key 和 value 是当前新计算的 K, V            # self.layer_idx 标识当前是哪一层的缓存  key, value = past_key_value.update(key, value, self.layer_idx, cache_kwargs)  # 此时的 key 和 value 包含了历史信息和当前新计算的信息  # 4. 计算注意力权重 (Q @ K^T)        # ...        attn_weights = torch.matmul(query, key.transpose(-1, -2))  # ... 应用注意力掩码 (causal mask, padding mask) ...  # 5. 计算注意力输出 (attn_weights @ V)        attn_output = torch.matmul(attn_weights, value)  # ... 合并多头,返回结果 ...  if use_cache:  # 如果使用缓存,则 present_key_value 就是更新后的 past_key_value            present_key_value = past_key_value  else:  present_key_value = None  return attn_output, present_key_value # 返回注意力的输出和更新后的缓存

关键点解释:

  • past_key_value (或 layer_past): 这是从上一个时间步或上一个调用传递过来的缓存对象。它包含了到目前为止所有先前 token 的 K 和 V 向量。
  • cache_position: 这是一个非常重要的参数,尤其是在使用了诸如 Flash Attention 2 等更高级的注意力实现时。它告诉缓存 update 方法以及注意力计算函数,新的 K 和 V 向量应该被放置在缓存张量的哪个位置。这对于正确地处理填充(padding)和动态序列长度至关重要。例如,如果当前输入的是第 t 个 token(从0开始计数),cache_position 可能就是 t
  • self.layer_idx: Transformer 模型通常由多个相同的注意力层堆叠而成。每一层都有自己独立的 KV 缓存。layer_idx 用于标识当前正在处理的是哪一层的缓存,确保数据被正确地存取。
  • use_cache: 控制是否使用和返回缓存。在训练时通常为 False(除非进行特定类型的训练,如 teacher forcing 的逐token训练),在推理(生成)时为 True

5.3 缓存的更新机制 (Cache.update)

Cache 对象的 update 方法是实现缓存的核心。虽然具体的实现会因 DynamicCacheStaticCache 而异,但其基本逻辑是:

class DynamicCache(Cache):  def __init__(self):  self.key_cache: List[torch.Tensor] = [] # 每层一个 tensor        self.value_cache: List[torch.Tensor] = [] # 每层一个 tensor        self.seen_tokens = 0 # 已缓存的token数量  def update(  self,  key_states: torch.Tensor,    # 新计算的 key        value_states: torch.Tensor,  # 新计算的 value        layer_idx: int,              # 当前层索引  cache_kwargs: Optional[Dict[str, Any]] = None,  ) -> Tuple[torch.Tensor, torch.Tensor]:  # 获取 cache_position        cache_position = cache_kwargs.get("cache_position")  # 如果是第一次更新这一层 (或缓存为空)  if layer_idx >= len(self.key_cache):  # 初始化该层的缓存张量  # ... 根据 key_states 和 value_states 的形状以及预估的最大长度(或动态调整)  self.key_cache.append(torch.zeros_like(key_states_preallocated))  self.value_cache.append(torch.zeros_like(value_states_preallocated))  # 将新的 key_states 和 value_states 写入到缓存的指定位置  # 对于 DynamicCache,通常是直接拼接或在预分配空间中按位置写入  if cache_position is not None:  # 使用 cache_position 精确地更新缓存的特定部分  # 例如: self.key_cache[layer_idx][:, :, cache_position, :] = key_states            #       self.value_cache[layer_idx][:, :, cache_position, :] = value_states            # 这里的维度可能需要根据实际实现调整  # 重要的是理解 cache_position 的作用  # 例如,如果 key_states 的形状是 (batch, num_heads, new_seq_len, head_dim)            # cache_position 的形状可能是 (batch, new_seq_len) 或广播的 (new_seq_len)            # 需要将 key_states 放置到 self.key_cache[layer_idx] 的正确"槽位"  # 对于自回归,通常 new_seq_len = 1            self.key_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=key_states)  self.value_cache[layer_idx].index_copy_(dim=2, index=cache_position, source=value_states)  # 更新已见过的token数量  self.seen_tokens = cache_position[-1] + 1 # 取最后一个新token的位置加1  else: # 旧的、不使用 cache_position 的逻辑(通常是简单拼接)  self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)  self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)  self.seen_tokens += key_states.shape[2]  # 返回包含所有历史信息(包括刚更新的)的 K 和 V 状态  return self.key_cache[layer_idx], self.value_cache[layer_idx]

update 方法的关键职责:

  1. 接收当前新计算的 key_statesvalue_states
  2. 根据 layer_idx 找到对应层的缓存。
  3. (可选,但推荐)使用 cache_position 将新的 K, V 向量精确地放置到缓存张量的正确位置。这对于处理批处理中不同样本有不同历史长度的情况(例如,在束搜索beam search后或 speculative decoding 后),或者在有填充 token 时非常重要。
  4. 返回完整的、包含所有历史信息和当前新信息的 K, V 向量,供后续的注意力计算使用。
  5. 更新内部状态,如已缓存的 token 数量 (seen_tokens)。

5.4 在模型整体 (GPT2Model) 的 forward 方法中处理

GPT2Modelforward 方法负责协调整个模型的流程,包括缓存的初始化、传递和 cache_position 的计算。

class GPT2Model(GPT2PreTrainedModel):  def forward(  self,  input_ids: Optional[torch.LongTensor] = None,  past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, # 旧版缓存元组  attention_mask: Optional[torch.FloatTensor] = None,  # ...  use_cache: Optional[bool] = None,  output_attentions: Optional[bool] = None,  output_hidden_states: Optional[bool] = None,  return_dict: Optional[bool] = None,  cache_position: Optional[torch.LongTensor] = None,  ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:  # ... (处理输入ID和嵌入) ...  inputs_embeds = self.wte(input_ids) # 词嵌入  position_embeds = self.wpe(position_ids) # 位置嵌入  hidden_states = inputs_embeds + position_embeds  # 1. 缓存初始化和类型转换  if use_cache:  if past_key_values is None: # 如果是第一次调用或没有提供缓存  # 根据配置决定使用哪种缓存,通常是 DynamicCache                # 例如:self.config.cache_implementation == "dynamic"  past_key_values = DynamicCache()  elif not isinstance(past_key_values, Cache):  # 为了兼容旧的元组形式的缓存,将其转换为新的 Cache 对象  past_key_values = DynamicCache.from_legacy_cache(past_key_values)  # else: past_key_values 保持为 None  # 2. 计算 cache_position        if cache_position is None: # 如果外部没有提供 cache_position            # 获取当前缓存中已有的 token 数量  past_seen_tokens = past_key_values.get_seq_length(self.config.num_hidden_layers) if past_key_values is not None else 0  # 当前输入序列的长度  current_seq_length = inputs_embeds.shape[1]  # cache_position 从 past_seen_tokens 开始,长度为 current_seq_length            cache_position = torch.arange(  past_seen_tokens, past_seen_tokens + current_seq_length, device=inputs_embeds.device  )  # else: 使用外部传入的 cache_position  # ... (准备注意力掩码,考虑因果关系和缓存长度) ...  # 3. 逐层传递和更新缓存  all_hidden_states = () if output_hidden_states else None  all_self_attentions = () if output_attentions else None  # next_decoder_cache 用于收集下一轮的缓存 (如果 use_cache 为 True)        # 在新的 Cache 对象设计中,past_key_values 本身会被原地更新或返回更新后的版本  # 因此,这个 next_decoder_cache 可能不再是必需的,或者其角色由 past_key_values 自身承担  for i, block in enumerate(self.h): # self.h 是 GPT2Block 的列表  # ...  # 将当前层的缓存 (如果存在) 和 cache_position 传递给 GPT2Block            # GPT2Block 内部会再将其传递给 GPT2Attention            layer_outputs = block(  hidden_states,  layer_past=None, # 旧参数,通常为None  attention_mask=extended_attention_mask,  head_mask=head_mask[i],  encoder_hidden_states=None,  encoder_attention_mask=None,  use_cache=use_cache,  output_attentions=output_attentions,  past_key_value=past_key_values, # 传递整个缓存对象  cache_position=cache_position,  )  hidden_states = layer_outputs[0] # 更新 hidden_states            # 如果 use_cache,block 会返回更新后的缓存,这里 past_key_values 已被更新  # (在 Cache 对象实现中,update 方法通常返回更新后的完整缓存状态,  #  或者直接在对象内部修改,取决于具体实现)  # ... (处理输出) ...  return BaseModelOutputWithPast(  last_hidden_state=hidden_states,  past_key_values=past_key_values if use_cache else None, # 返回更新后的缓存  hidden_states=all_hidden_states,  attentions=all_self_attentions,  )

5.5 因果掩码 (Causal Mask) 与 KV 缓存的配合

在自回归生成中,模型只能注意到当前 token 及其之前的所有 token,不能注意到未来的 token。这是通过因果掩码实现的。当使用 KV 缓存时,因果掩码的构建需要考虑到缓存中已有的 token 数量。

class GPT2Attention(_GPT2Attention):  def _update_causal_mask(  self,  attention_mask: torch.Tensor, # 原始的 attention_mask (可能包含 padding)        input_tensor: torch.Tensor,   # 当前输入的 hidden_states        cache_position: torch.Tensor,  past_key_values: Cache,       # 当前的缓存对象  output_attentions: bool,  ):  # 获取当前输入的序列长度 (通常为1,在自回归生成的每一步)  input_seq_length = input_tensor.shape[1]  # 获取缓存中已有的序列长度  past_seen_tokens = past_key_values.get_seq_length(self.layer_idx)  # 总的上下文长度 = 缓存长度 + 当前输入长度  total_context_length = past_seen_tokens + input_seq_length  # _prepare_4d_causal_attention_mask_with_cache_position 会生成一个正确的掩码  # 这个掩码会确保:  # 1. 查询 Q (来自当前输入) 只能注意到键 K (来自缓存+当前输入) 中对应位置及之前的部分。  # 2. 处理好 padding (如果 attention_mask 中有指示)。  # 形状通常是 (batch_size, 1, query_length, key_length)        # 其中 query_length 是当前输入的长度 (如1)  # key_length 是总的上下文长度 (past_seen_tokens + input_seq_length)        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(  attention_mask,  input_shape=(input_tensor.shape[0], input_seq_length), # 当前输入的形状  target_length=total_context_length, # K, V 的总长度  dtype=input_tensor.dtype,  cache_position=cache_position, # 关键!用于确定当前 Q 在 K,V 序列中的相对位置  )  return causal_mask

_prepare_4d_causal_attention_mask_with_cache_position 这个辅助函数会创建一个上三角矩阵(或类似结构),其中未来的位置会被掩盖掉(例如,设置为一个非常小的负数,以便 softmax 后变为0)。cache_position 在这里的作用是,确保即使当前查询 Q 的序列长度很短(例如为1),它在与历史的 K, V 进行比较时,依然能正确地只关注到历史和当前 K, V 中该 Q 之前的部分。

5.6 支持多种高效的注意力实现

Hugging Face Transformers 库允许 GPT-2(以及其他模型)利用更高效的注意力后端实现,例如:

  • eager: PyTorch 的标准、原生注意力实现。
  • sdpa (Scaled Dot Product Attention): PyTorch 2.0 引入的高度优化的注意力函数 torch.nn.functional.scaled_dot_product_attention。它通常比 eager模式更快,内存效率也更高,并且可以自动选择最优的底层实现(如 FlashAttention 或 memory-efficient attention)。
  • flash_attention_2: 直接集成 FlashAttention v2 库。这是一种专门为现代 GPU 设计的、IO 感知的精确注意力算法,速度非常快,内存占用小。

KV 缓存机制的设计需要与这些高效实现兼容。例如,torch.nn.functional.scaled_dot_product_attention 和 FlashAttention 都支持直接传入包含历史和当前信息的完整 K, V 张量。cache_position 在这里尤为重要,因为它可以帮助这些高效后端理解哪些部分是新的,哪些是旧的,以及如何正确应用因果掩码。

# 在 GPT2Attention 的 forward 方法中
self.config._attn_implementation 存储了选择的注意力实现方式 ("eager", "sdpa", "flash_attention_2")  ... (计算 query, key, value) ...  
... (更新 key, value 使用 past_key_value 和 cache_position) ...  
此时 key 和 value 是拼接/更新后的完整 K, V  if self.config._attn_implementation == "sdpa":  # 使用 PyTorch SDPA    # is_causal=True 会自动应用因果掩码  # attn_mask 可能需要根据 SDPA 的要求进行调整  attn_output = torch.nn.functional.scaled_dot_product_attention(  query, key, value, attn_mask=adjusted_attn_mask, dropout_p=self.attn_dropout.p, is_causal=True  )  
elif self.config._attn_implementation == "flash_attention_2":  # from flash_attn import flash_attn_func  # 可能需要对 query, key, value 的形状或数据类型进行调整以适应 flash_attn_func    # causal=True 会应用因果掩码  attn_output = flash_attn_func(  query.transpose(1, 2), # FlashAttention 可能期望 (batch, seq_len, num_heads, head_dim)        torch.stack((key.transpose(1,2), value.transpose(1,2)), dim=0), # K, V 打包  dropout_p=self.attn_dropout.p,  causal=True,  )  
else: # "eager"  # ... (标准的 PyTorch matmul 实现) ...

5.7 KV 缓存的完整工作流程 (自回归生成)

5.7.1 初始步骤 (t=0):
  • 用户提供初始的 input_ids (例如,一个 [BOS] token 或者一段提示文本)。
  • past_key_valuesNone
  • 模型 forward 方法被调用。
  • use_cache 通常为 True
  • 初始化一个空的 DynamicCache 对象作为 past_key_values
  • 计算 cache_position,此时它通常是从 0 开始的序列 (e.g., torch.arange(0, initial_input_len)).
  • 对于每一注意力层:
    • 计算当前 input_ids 对应的 Q, K, V。
    • 由于 past_key_values 刚被初始化(内部缓存为空),update 方法会将这些新计算的 K, V 存入缓存的第一批位置。
    • 使用这些 K, V (此时它们只包含当前输入的信息) 和 Q 进行注意力计算。
  • 模型输出 logits (用于预测下一个 token) 和更新后的 past_key_values (现在包含了第一个输入的 K,V)。
5.7.2 后续步骤 (t > 0):
  • 从上一步的 logits 中采样得到新的 input_ids (通常是一个新的 token)。
  • 将上一步返回的 past_key_values (包含了 t-1 步及之前所有 token 的 K,V) 作为输入传递给模型。
  • 模型 forward 方法再次被调用。
  • use_cacheTrue
  • 计算 cache_position。此时,past_key_values.get_seq_length() 会返回已缓存的 token 数量 (例如 t)。新的 cache_position 会是 torch.tensor([t]),表示这个新 token 是序列中的第 t+1 个元素 (如果从1开始计数的话,或者第 t 个位置如果从0开始计数)。
  • 对于每一注意力层:
    • 只对新输入的单个 token 计算其 Q, K, V (这些是"小"张量)。
    • 调用 past_key_values.update(new_key, new_value, layer_idx, cache_kwargs={"cache_position": cache_position})
      • update 方法会将这个新 token 的 K, V 追加到对应层缓存中已有的 K, V 之后,并返回完整的 K (包含所有 t+1 个 token) 和完整的 V。
    • 使用新 token 的 Q 和完整的 (历史+当前) K, V 计算注意力。因果掩码会确保 Q 只注意到 K,V 中它自己及之前的部分。
  • 模型输出 logits 和再次更新后的 past_key_values

这个过程一直重复,直到生成了 [EOS] token 或达到最大长度。

KV 缓存的显著优势

  1. 避免冗余计算: 这是最核心的优势。在生成第 t 个 token 时,前 t-1 个 token 的 K 和 V 向量已经计算并存储在缓存中,无需重新计算。注意力机制只需要为新的当前 token 计算 K 和 V,然后将它们与缓存中的历史 K,V 结合起来。
  2. 显著提高生成速度: 尤其对于长序列生成,每次迭代的计算量从 O(N²)(N为当前总长度)降低到接近 O(N)(主要是新 Q 与历史 K,V 的交互),因为主要计算瓶颈(K,V的生成)只针对新token进行。
  3. 支持高效的批处理生成: 虽然每个样本在批次中可能有不同的已生成长度(特别是在使用可变长度输入或某些采样策略时),通过 cache_position 和可能的填充/掩码机制,KV 缓存可以有效地处理这种情况。
  4. 与先进注意力实现的兼容性: 如前所述,KV 缓存的设计与 SDPA、FlashAttention 等高效后端良好集成,使得模型可以同时享受到算法优化和底层硬件加速的好处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/83818.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年渗透测试面试题总结-各厂商二面试题02(题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 各厂商二面试题02 模块六:基础技术扩展 1. HTTP请求方式 2. 域名解析工具与技术 3. Web十…

专业漏洞扫描机构如何助力企业保障安全并提升竞争力?

在这个信息化的当下,专业漏洞扫描机构扮演着至关重要的角色。他们运用专业的技术和手段,对各种软件和系统进行细致的漏洞检测,确保其安全可靠,同时帮助企业提高产品质量和市场竞争力。 服务项目 我们专注于向客户供应周到详尽的…

卷积神经网络中的二维卷积与三维卷积详解

【内容摘要】 本文聚焦卷积神经网络中的二维卷积与三维卷积,详细解析两者的区别、操作原理及应用场景,涵盖二维/三维卷积操作示意图、多通道输入处理方式,以及RGB图像不采用三维卷积的原因,助力理解不同卷积类型的特性与适用场景。…

Oracle 的 ASSM 表空间

Oracle 的 ASSM(Automatic Segment Space Management)表空间 是一种自动管理段空间的技术,通过位图(Bitmap)机制跟踪数据块的使用情况,替代传统的手动管理(MSSM,即 Freelist 管理&am…

蚂蚁金服大数据面经及参考答案

Java 如何保证跨平台性?请从 JVM 底层适配机制及向上提供的统一接口角度说明 Java 的跨平台性是其核心优势之一,依赖于 JVM(Java Virtual Machine)的底层适配机制和向上层提供的统一接口。从底层来看,JVM 针对不同操作系统和硬件平台进行了定制化实现,负责解析和执行 Ja…

P1009 [NOIP 1998 普及组] 阶乘之和

题目描述 用高精度计算出 S1!2!3!⋯n!(n≤50)。 其中 ! 表示阶乘,定义为 n!n(n−1)(n−2)⋯1。例如,5!54321120。 输入格式 一个正整数 n。 输出格式 一个正整数 S,表示计算结果。 输入输出样例 输入 3 输出…

Python 的 os 库常见使用方法(操作目录及文件)

前言: os 模块是 Python 标准库中用于与操作系统交互的核心模块,提供了许多操作文件和目录的功能。以下是常见的使用方法: 1. 目录操作 方法功能说明示例os.getcwd()获取当前工作目录print(os.getcwd())os.chdir(path)切换当前工作目录os.ch…

vue3 el-table实现字段可编辑

在Vue 3中,如果你想让el-table(Element Plus的表格组件)的字段可编辑,你可以通过以下方式来实现: 使用cell-mouse-enter和cell-mouse-leave事件动态显示编辑图标或控件 你可以在鼠标进入单元格时显示一个编辑图标或输…

基于shardingsphere的分库分表方案

一、准备docker容器 启动两个mysql的docker容器 docker run -v /root/mysql_volume/data:/var/lib/mysql -v /root/mysql_volume/conf:/etc/mysql/conf.d -v /root/mysql_volume/my.cnf:/etc/my.cnf -p 3306:3306 --name mysql --restartalways --privilegedtrue -e MYSQL_RO…

SearxNG本地搜索引擎

SearxNG 是一个强大、开源的 元搜索引擎(meta search engine),它不会存储用户信息,注重隐私保护,并支持从多个搜索引擎聚合结果,用户可以自建部署,打造一个无广告、可定制的搜索平台。 🔍 什么是 SearxNG? SearxNG 是 Searx 的一个积极维护的分支(fork),意在改进…

Vue3.5 企业级管理系统实战(十九):菜单管理

篇幅原因,本节先探讨菜单管理页面增删改查相关功能,角色菜单,菜单权限,动态菜单等内容放在后面。 1 菜单 api 在 src/api/menu.ts 中添加菜单 api,代码如下: //src/api/menu.ts import service from &qu…

【android bluetooth 协议分析 01】【HCI 层介绍 8】【ReadLocalVersionInformation命令介绍】

1. HCI_Read_Local_Version_Information 命令介绍 1. 功能(Description) HCI_Read_Local_Version_Information 命令用于读取本地 Bluetooth Controller 的版本信息,包括 HCI 和 LMP 层的版本,以及厂商 ID 和子版本号。 这类信息用…

React底层架构深度解析:从虚拟DOM到Fiber的演进之路

一、虚拟DOM:性能优化的基石 1.1 核心工作原理 React通过JSX语法将组件转换为轻量级JavaScript对象(即虚拟DOM),而非直接操作真实DOM。这一过程由React.createElement()实现,其结构包含元素类型、属性和子节点等信息&a…

从AlphaGo到ChatGPT:AI技术如何一步步改变世界?

从AlphaGo到ChatGPT:AI技术如何一步步改变世界? 这里给大家分享一个人工智能学习网站。点击跳转到网站。 https://www.captainbed.cn/ccc 前言 在科技发展的历史长河中,人工智能(AI)技术无疑是最为璀璨的明珠之一。从…

关于在Unity项目中使用Post Processing插件打包到web端出现的问题

关于在Unity项目中使用Post Processing插件打包到web端出现的问题 解决方法:是不激活摄像机上的Post Processing有关组件,拉低场景中的Directional Light平行光的强度进行web端打包。 (烘焙灯光时是可以激活。) web端支持这个Pos…

MySQL - 如何突破单库性能瓶颈

数据库服务器硬件优化 我们来看看对数据库所在的服务器是如何进行优化的,服务器是数据库的宿主,其性能直接影响了数据库的性能,所以服务器的优化也是数据库优化的第一步。 数据库服务器通常是从 CPU、内存、磁盘三个角度进行硬件优化的&…

用 CodeBuddy 搭建「MiniGoal 小目标打卡器」:一次流畅的 UniApp 开发体验

我正在参加CodeBuddy「首席试玩官」内容创作大赛,本文所使用的 CodeBuddy 免费下载链接:腾讯云代码助手 CodeBuddy - AI 时代的智能编程伙伴 在日常生活中,我们总是希望能够坚持一些小习惯,比如每天锻炼十分钟、读一页书、早睡十分…

OpenCV 环境搭建与概述

// //OpenCV-4.11.0 C VS2019 // 一、OpenCV学习路线 1、入门: OpenCV图像读写、视频读写、基本像素处理、基本卷积处理、基本C开发知识。 2、初级: OpenCV自定义卷积操作、图像梯度、边缘提取、二值分析、视频分析、形态学处理、几何变换与透视变换。 3、中级: 角点查找、BL…

如何快速更换电脑浏览器ip:教程与注意事项

无论是为了访问地域限制内容、保护隐私,还是解决网络问题,快速更换浏览器IP地址的需求日益增多。以下是快速更换电脑浏览器IP地址的几种常用方法及注意事项,结合了多种场景下的解决方案: 一、快速更换浏览器IP的方法 1. 代理服务…

【kafka】kafka概念,使用技巧go示例

1. Kafka基础概念 1.1 什么是Kafka? Kafka是一个分布式流处理平台,用于构建实时数据管道和流式应用。核心特点: 高吞吐量:每秒可处理百万级消息持久化存储:消息按Topic分区存储在磁盘分布式架构:支持水平…