出发点
上一篇解析了Chatglm2-6b的模型架构,并和Chatglm-6b进行对比,但是留下了几个问题(哭)这一篇的目的是讲明白attention和rotaryEmbedding,解决问题,并实现整体目标,完全替代modeling_chatglm.py,并将代码缩减到一半儿。
selfattention
class SelfAttention(torch.nn.Module):"""Parallel self-attention layer abstract class.Self-attention layer takes input with size [s, b, h]and returns output of the same size."""def __init__(self, config: ChatGLMConfig, layer_number, device=None):super(SelfAttention, self).__init__()self.layer_number = max(1, layer_number)self.projection_size = config.kv_channels * config.num_attention_heads# 128*32=4096 hidden_size# Per attention head and per partition values.self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads# 128 每个attention头的hidden_sizeself.num_attention_heads_per_partition = config.num_attention_heads# 32 attention头数self.num_multi_query_groups_per_partition = config.multi_query_group_num# 2 分了多少组self.qkv_hidden_size = (self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num)# 4096+2*128*2=4608 qkv对应的hidden_size# 稍微解释一下为什么不是4096*3,因为这里使用了GQA的思想,下文会简单介绍一下self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,bias=config.add_bias_linear or config.add_qkv_bias,device=device, **_config_to_kwargs(config))self.core_attention = CoreAttention(config, self.layer_number)# Output.self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,device=device, **_config_to_kwargs(config))def forward(self, hidden_states, rotary_pos_emb, kv_cache=None, use_cache=True):# hidden_states: [sq, b, h]# =================================================# Pre-allocate memory for key-values for inference.# =================================================# =====================# Query, Key, and Value# =====================# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]mixed_x_layer = self.query_key_value(hidden_states)(query_layer, key_layer, value_layer) = mixed_x_layer.split([self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,],dim=-1,)query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))value_layer = value_layer.view(value_layer.size()[:-1]+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head))# apply relative positional encoding (rotary embedding)if rotary_pos_emb is not None:query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)# adjust key and value for inferenceif kv_cache is not None:cache_k, cache_v = kv_cachekey_layer = torch.cat((cache_k, key_layer), dim=0)value_layer = torch.cat((cache_v, value_layer), dim=0)if use_cache:kv_cache = (key_layer, value_layer)else:kv_cache = Nonekey_layer = key_layer.unsqueeze(-2)key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))# GQA的操作:重复多次到原始尺寸,即32,128value_layer = value_layer.unsqueeze(-2)value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))# GQA的操作:重复多次到原始尺寸,即32,128# ==================================# core attention computation# ==================================context_layer = self.core_attention(query_layer, key_layer, value_layer)# 核心操作attention,和Chatglm-6b中attention_fn是一样的# =================# Output. [sq, b, h]# =================output = self.dense(context_layer)return output, kv_cache
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
可以看出来思想也比较朴素,MHA中query、key、value都是一对一的,这样虽然效果好,但是caches太多了。MQA中只有一组key和value,和多个query相对应,caches减少了,但是效果会不好。那GQA则取个平均,有g组key和value,每一组key和value都重复几次和query相对应。
GQA提供了MHA到MQA的自然过渡,当g=h时就是MHA,g=1时就是MQA,当1<g<h时,它只将KV Cache压缩到g/h,压缩率不如MQA,但同时也提供了更大的自由度,效果上更有保证。
这里也贴一下Fast Transformer Decoding: One Write-Head is All You Need
那这里就解决了两个问题:
- multi_query_group_num是GQA中要分组的数量
- kv_channels对应的是query、key、value每个头的hidden_size
coreattention
class CoreAttention(torch.nn.Module):def __init__(self, config: ChatGLMConfig, layer_number):super(CoreAttention, self).__init__()self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling# 对query、key层是否要进行缩放,实际是要缩放的self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32# softmax的精度要使用fp32self.layer_number = max(1, layer_number)# Per attention head and per partition values.self.hidden_size_per_partition = config.kv_channels * config.num_attention_heads# 128*32self.hidden_size_per_attention_head = config.kv_channels# 128self.num_attention_heads_per_partition = config.num_attention_heads# 32self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)# sqrt(d)的操作self.attention_dropout = torch.nn.Dropout(config.attention_dropout)def forward(self, query_layer, key_layer, value_layer):pytorch_major_version = int(torch.__version__.split('.')[0])if pytorch_major_version >= 2:query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]if query_layer.shape[2] == key_layer.shape[2]:# 只会在生成第一个token的时候,走这条路context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,is_causal=True)# 从这里可以看出来Chatglm2-6b完全就是一个decoder only的模型else:# 这时候query的长度是1,key的长度是总token的长度context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,None)context_layer = context_layer.permute(2, 0, 1, 3)new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)context_layer = context_layer.reshape(*new_context_layer_shape)else:# Raw attention scores# [b, np, sq, sk]output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))# [sq, b, np, hn] -> [sq, b * np, hn]query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)# [sk, b, np, hn] -> [sk, b * np, hn]key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)# preallocting input tensor: [b * np, sq, sk]matmul_input_buffer = torch.empty(output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,device=query_layer.device)# Raw attention scores. [b * np, sq, sk]matmul_result = torch.baddbmm(matmul_input_buffer,query_layer.transpose(0, 1), # [b * np, sq, hn]key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]beta=0.0,alpha=(1.0 / self.norm_factor),)# Chatglm-6b中将alpha放在了前面,让query单独除了一下,没啥结果上的差别# 关于torch.baddbmm多说一句,因为beta=0,所以input选择empty没啥问题,反正要被跳过# change view to [b, np, sq, sk]attention_scores = matmul_result.view(*output_size)# ===========================# Attention probs and dropout# ===========================# attention scores and attention mask [b, np, sq, sk]if self.attention_softmax_in_fp32:attention_scores = attention_scores.float()if attention_scores.shape[2] == attention_scores.shape[3]:attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],device=attention_scores.device, dtype=torch.bool)attention_mask.tril_()attention_mask = ~attention_maskelse:attention_mask = None"""重点看一下这一小段代码,当sq=sk时(即query长度和key长度一致时,给了一个attention_mask)此时的attention_mask其实就是一个上三角为True、下三角为False的矩阵结合后面的 attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) 这一句的操作就是将上三角的scores值置为负无穷,这妥妥的就是decoder-only嘛当sq!=sk时,attention_mask即为空,即预测第二个token时,此时query长度为1,而key长度带着之前的cache,所以长度>1,此时不相等,attention_mask为空,后续也就没有啥操作了"""if attention_mask is not None:attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))attention_probs = F.softmax(attention_scores, dim=-1)attention_probs = attention_probs.type_as(value_layer)# This is actually dropping out entire tokens to attend to, which might# seem a bit unusual, but is taken from the original Transformer paper.attention_probs = self.attention_dropout(attention_probs)# =========================# Context layer. [sq, b, hp]# =========================# value_layer -> context layer.# [sk, b, np, hn] --> [b, np, sq, hn]# context layer shape: [b, np, sq, hn]output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))# change view [sk, b * np, hn]value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)# change view [b * np, sq, sk]attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)# matmul: [b * np, sq, hn]context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))# change view [b, np, sq, hn]context_layer = context_layer.view(*output_size)# [b, np, sq, hn] --> [sq, b, np, hn]context_layer = context_layer.permute(2, 0, 1, 3).contiguous()# [sq, b, np, hn] --> [sq, b, hp]new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)context_layer = context_layer.view(*new_context_layer_shape)return context_layer
这里多写一句,代码中有关于self.coeff的操作,即layer_number
在代码中self.norm_factor=self.coeff *math.sqrt(self.hidden_size_per_attention_head)
在计算attention_scores中除以了self.coeff *math.sqrt(self.hidden_size_per_attention_head)
然后在计算softmax之前又将attention_scores乘以了self.coeff
那不就相当于只是除以了math.sqrt(self.hidden_size_per_attention_head)嘛????
不知道为什么要有这个操作,感觉怪怪的,最主要的是不知道目的,有了解的可以解释一下,谢谢
之前Chatglm-6b的代码中就有这样的操作,当时没注意到(汗),这里的代码是直接删去了这个操作,完全没影响的。
当然了因为在pytorch_major_version >= 2中其实是没有和layer_number相关的操作,这个时候应该就能明白这个操作是无用的了。
RotaryEmbedding
class RotaryEmbedding(nn.Module):def __init__(self, dim, original_impl=False, device=None, dtype=None):super().__init__()inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))self.register_buffer("inv_freq", inv_freq)self.dim = dimself.original_impl = original_impldef forward_impl(self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000):"""Enhanced Transformer with Rotary Position Embedding.Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/transformers/rope/__init__.py. MIT License:https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license."""# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))# Create position indexes `[0, 1, ..., seq_len - 1]`seq_idx = torch.arange(seq_len, dtype=dtype, device=device)# Calculate the product of position index and $\theta_i$idx_theta = torch.outer(seq_idx, theta).float()cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)# this is to mimic the behaviour of complex32, else we will get different resultsif dtype in (torch.float16, torch.bfloat16, torch.int8):cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()return cachedef forward(self, max_seq_len, offset=0):return self.forward_impl(max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device)@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:# x: [sq, b, np, hn]sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)rot_dim = rope_cache.shape[-2] * 2# 32*2x, x_pass = x[..., :rot_dim], x[..., rot_dim:]# [:64],[64:] 将输入根据隐藏层维度,拆分得到两部分,只针对前部分x计算旋转位置信息# truncate to support variable sizesrope_cache = rope_cache[:sq]xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)# [q_0,q_1][q_2,q_3]rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)# [cos0,sin0][cos1,sin1]x_out2 = torch.stack([xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],# 对应复数的实部q_0*cos(m\theta)-q_1*sin(m\theta)# [q0, q2, ] *[cos0, cos1] - [q1, q3, ] *[sin0, sin1]xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],# 对应复数的虚部q_1*cos(m\theta)+q_0*sin(m\theta)# [q1, q3, ] *[cos0, cos1] + [q0, q2, ] *[sin0, sin1]],-1,)# q0cos0-q1sin0# q1cos0+q0sin0# q2cos1-q3sin1# q3cos1+q2sin1x_out2 = x_out2.flatten(3)return torch.cat((x_out2, x_pass), dim=-1)
这里就可以解释位置Embedding中传入的dim为什么是rotary_dim // 2了,因为它只对一半的hidden_size进行了位置编码,这也是很迷的一项操作,我没看到什么很好的解释,有了解原因的,欢迎指导,谢谢
最后一点代码量
到此基本就写完了代码,最后补充上两个函数和一点import
""" PyTorch ChatGLM model. """import math
import copy
import reimport torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from transformers.modeling_utils import PreTrainedModel
from configuration_chatglm import ChatGLMConfigdef _config_to_kwargs(args):common_kwargs = {"dtype": args.torch_dtype,}return common_kwargsclass ChatGLMPreTrainedModel(PreTrainedModel):"""An abstract class to handle weights initialization anda simple interface for downloading and loading pretrained models."""is_parallelizable = Falseconfig_class = ChatGLMConfigbase_model_prefix = "transformer"_no_split_modules = ["GLMBlock"]
把这些代码保存成chatglm.py,放在chatglm2-6b的代码中,就可以正常使用了,使用方法和chatglm-6b是一样的
from chatglm import *
from transformers import AutoTokenizer
model_path = "/usr/downloads/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = ChatGLMForConditionalGeneration.from_pretrained(model_path, trust_remote_code=True).half().cuda()prompt = '你好'
response = model.chat(tokenizer, prompt)
代码量在650行,原始代码量是1280,减少一半的代码的小目标基本实现(成功)
参数量
简单分析一下参数量,其实从模型结构里就能很明白的看出来了,我这里就是记录一下
# word embedding
65024*4096*2=532676608
# 最后一层后面的LN
4096
# 下面几个是每层都有的
# query_key_value
4608*4096=18874368
# query_key_value.bias
4608
# dense
4096*4096=16777216
# LN
2*4096
# dense_h_to_4h
4096*27392=112197632
# dense_4h_to_h
13696*4096=56098816# 28层
(18874368+4608+16777216+2*4096+112197632+56098816)*28=5710903296
5710903296+532676608+4096=6243584000
# 可以看出来主要的参数还是在word Embedding和dense_h_to_4h
结束语
这次解析了chatglm2-6b的代码,将代码缩减到650行,并分析了与chatglm-6b的区别,其实从结构里就可以看出来,它已经不是GLM的架构了,完全是一个decoder only的结构。改为了使用了RMSNorm、使用了GQA缩减caches、激活函数使用swiglu,基本就是这些了。
补充一点:经过查看代码,发现chatglm3-6b和chatglm2-6b的代码基本一模一样,只有在tokenizer处理输入的时候和返回response的时候有一点不一样,所以就不对chatglm3-6b做单独的介绍了。