大模型推理——MLA实现方案

1.整体流程

先上一张图来整体理解下MLA的计算过程

2.实现代码

import math
import torch
import torch.nn as nn# rms归一化
class RMSNorm(nn.Module):""""""def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.variance_epsilon = epsdef forward(self, hidden_states):hidden_states = hidden_states.float()variance = hidden_states.pow(2).mean(-1, keepdim=True)hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)return self.weight * hidden_states.float()def rotate_half(x):x1, x2 = x.chunk(2, dim=-1)return torch.cat((-x2, x1), dim=-1)def apply_rotate_pos_emb(q, k, cos, sin, unsqueeze_dim=2):cos = cos.unsqueeze(unsqueeze_dim)sin = sin.unsqueeze(unsqueeze_dim)q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed# 旋转位置编码
class RotaryEmbedding(nn.Module):def __init__(self, dim, max_seq_len=1024):super(RotaryEmbedding, self).__init__()self.dim = dimself.max_seq_len = max_seq_leninv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))t = torch.arange(max_seq_len).float().unsqueeze(1)freqs = t @ inv_freq.unsqueeze(0)freqs = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", freqs.cos())self.register_buffer("sin_cached", freqs.sin())def forward(self, q, k):cos = self.cos_cached[:q.shape[1], :].unsqueeze(0)sin = self.sin_cached[:q.shape[1], :].unsqueeze(0)return apply_rotate_pos_emb(q, k, cos, sin)class MLA(nn.Module):def __init__(self,dim,n_heads,q_lora_rank,kv_lora_rank,qk_nope_head_dim,qk_rope_head_dim,v_head_dim,max_seq_len,max_batch_size,mode):super().__init__()self.dim = dim  # 隐藏层维度self.n_heads = n_heads  # 总头数self.q_lora_rank = q_lora_rank  # q低秩压缩到的维度self.kv_lora_rank = kv_lora_rank  # k/v低秩压缩到的维度self.qk_nope_head_dim = qk_nope_head_dim    # q/k不带旋转位置编码的维度self.qk_rope_head_dim = qk_rope_head_dim    # q/k带旋转位置编码的维度self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim  # q/k的总维度,不带旋转位置编码的维度加上带旋转位置编码的维度self.v_head_dim = v_head_dim  # value的维度,等于不带旋转位置编码的k维度self.mode = modeself.max_seq_len = max_seq_lenself.max_batch_size = max_batch_sizeself.wq_a = nn.Linear(self.dim, self.q_lora_rank)  # q的降维矩阵self.q_norm = RMSNorm(self.q_lora_rank)self.wq_b = nn.Linear(self.q_lora_rank, self.n_heads * self.qk_head_dim)  # q的升维矩阵# 4096*128+128*4864 = 524,288 + 622592 = 1146880    4096*4864 = 19,922,944self.wkv_a = nn.Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)  # k/v的降维矩阵# nn.Linear(self.dim, self.kv_lora_rank)# nn.Linear(self.dim, self.qk_rope_head_dim)self.kv_norm = RMSNorm(self.kv_lora_rank)self.wkv_b = nn.Linear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))  # k/v的升维矩阵self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim)self.rotary_emb = RotaryEmbedding(self.qk_rope_head_dim)  # 旋转位置编码# 没有矩阵融合if self.mode == 'naive':self.register_buffer('k_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.qk_head_dim),persistent=False)self.register_buffer('v_cache',torch.zeros(self.max_batch_size, self.max_seq_len, self.n_heads, self.v_head_dim),persistent=False)# 有矩阵融合else:self.register_buffer('kv_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.kv_lora_rank),persistent=False)self.register_buffer('pe_cache', torch.zeros(self.max_batch_size, self.max_seq_len, self.qk_rope_head_dim),persistent=False)def forward(self, x, mask=None):bs, seq_len, _ = x.shapeq = self.wq_a(x)  # [bs, seq_len, q_lora_rank]q = self.q_norm(q)  # [bs, seq_len, q_lora_rank]q = self.wq_b(q)  # [bs, seq_len, n_heads * qk_head_dim]q = q.view(bs, seq_len, self.n_heads, self.qk_head_dim)  # [bs, seq_len, n_heads, qk_head_dim]q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim],dim=-1)  # q_nope shape:[bs, seq_len, n_heads, qk_nope_head_dim] q_pe shape:[bs, seq_len, n_heads, qk_rope_head_dim]kv = self.wkv_a(x)  # [bs, seq_len, kv_lora_rank + qk_rope_head_dim]kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim],dim=-1)  # kv shape:[bs, seq_len, kv_lora_rank] k_pe shape:[bs, seq_len, qk_rope_head_dim]k_pe = k_pe.unsqueeze(2)  # k_pe shape:[bs, seq_len, 1, qk_rope_head_dim]   一层共享一个keyq_pe, k_pe = self.rotary_emb(q_pe, k_pe)if self.mode == 'naive':q = torch.cat([q_nope, q_pe], dim=-1)  # * [bs, seq_len, n_heads, qk_head_dim]kv = self.kv_norm(kv)  # [bs, seq_len, kv_lora_rank)]kv = self.wkv_b(kv)  # [bs, seq_len, n_heads * (qk_nope_head_dim + v_head_dim)]kv = kv.view(bs, seq_len, self.n_heads, self.qk_nope_head_dim + self.v_head_dim)k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1)# k shape:[bs, seq_len, n_heads, qk_head_dim]self.k_cache[:bs, :seq_len, :, :] = kself.v_cache[:bs, :seq_len, :, :] = v# scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bs, :seq_len]) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)scores = torch.matmul(q.transpose(1, 2),self.k_cache[:bs, :seq_len, :, :].transpose(1, 2).transpose(2, 3) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim))scores = scores.transpose(1, 2)else:k_pe = k_pe.squeeze(2)wkv_b = self.wkv_b.weight  # [n_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]wkv_b = wkv_b.view(self.n_heads, -1,self.kv_lora_rank)  # [n_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]q_nope = torch.einsum("bshd,hdc->bshc", q_nope,wkv_b[:, :self.qk_nope_head_dim])  # q_nope shape:[bs, seq_len, n_heads, kv_lora_rank]# q*k(T) = x*wq*(c*wkv_b[:, :self.qk_nope_head_dim])(T) = x*wq*wkv_b[:, :self.qk_nope_head_dim](T)*c(T)    c为压缩后的k/v# wq*wkv_b[:, :self.qk_nope_head_dim](T)作为q的投影矩阵  c可以替代原先的k,这样就可以直接使用压缩后的k/v计算注意力了,kv_cache时也只需存储压缩后的k/vkv = self.kv_norm(kv)self.kv_cache[:bs, :seq_len, :] = kv  # kv shape:[bs, seq_len, kv_lora_rank]self.pe_cache[:bs, :seq_len, :] = k_pe  # k_pe shape:[bs, seq_len, qk_rope_head_dim]scores_nope = torch.einsum("bshc,btc->bsht", q_nope,self.kv_cache[:bs, :seq_len, :])  # bshc btc -> bshc bct -> bshtscores_pe = torch.einsum("bshr,btr->bsht", q_pe,self.pe_cache[:bs, :seq_len, :])  # bshr btr -> bshr bt1r -> bshr bthr -> bshtscores = (scores_nope + scores_pe) / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)  # [bs, seq_len, n_heads, seq_len]if mask is not None:# mask shape:[bs, seq_len, seq_len]scores += mask.unsqueeze(2)scores = scores.softmax(dim=-1)if self.mode == 'naive':x = torch.einsum("bsht,bthd->bshd", scores,self.v_cache[:bs, :seq_len])  # bsht,bthd -> bhst, bhtd -> bhsd -> bshdelse:# scores * v = scores * c * wkv_b[:, -self.v_head_dim:]x = torch.einsum("bsht,btc->bshc", scores,self.kv_cache[:bs, :seq_len])  # x shape:[bs, seq_len, n_heads, kv_lora_rank]x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])  # bshc, hdc -> bshc,dch -> bsdh -> bshdx = x.contiguous().view(bs, seq_len, -1)x = self.wo(x) return xif __name__ == '__main__':torch.manual_seed(0)torch.set_printoptions(precision=3, sci_mode=False)x = torch.randn(1, 4, 16)dim = 16n_heads = 2q_lora_rank = 10kv_lora_rank = 6qk_nope_head_dim = 8qk_rope_head_dim = 4v_head_dim = 8max_seq_len = 10max_batch_size = 4mode = 'none'mla = MLA(dim=dim,n_heads=n_heads,q_lora_rank=q_lora_rank,kv_lora_rank=kv_lora_rank,qk_nope_head_dim=qk_nope_head_dim,qk_rope_head_dim=qk_rope_head_dim,v_head_dim=v_head_dim,max_seq_len=max_seq_len,max_batch_size=max_batch_size,mode=mode)print(mla(x))print(mla.kv_cache)

参考资料:

https://zhuanlan.zhihu.com/p/16730036197

https://github.com/wyf3/llm_related/tree/main/deepseek_learn

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Nginx + Keepalived 实现高可用的负载均衡架构】

使用 Nginx Keepalived 可以实现高可用的负载均衡架构,确保在某个 Nginx 节点故障时,自动将流量转移到备用节点。以下是详细的实现步骤: 1. 架构概述 Nginx:作为负载均衡器,将流量分发到后端服务器。Keepalived&…

MySQL 8.0.41安装教程(2025年2月8号)

下载网址:https://www.mysql.com/cn/downloads/ 点击 我选择的是第二个离线安装 点击之后,选择直接下载: 下载完成双击: 我选择的是自定义安装: 右边默认已经存在我选择的8.0.41 点击红框中的,自定义安装路…

WPS中解除工作表密码保护(忘记密码)

1.下载vba插件 项目首页 - WPS中如何启用宏附wps.vba.exe下载说明分享:WPS中如何启用宏:附wps.vba.exe下载说明本文将详细介绍如何在WPS中启用宏功能,并提供wps.vba.exe文件的下载说明 - GitCode 并按照步骤安装 2.wps中点击搜索,输入开发…

RNN-day1-NLP基础

NLP基础 一、基本概念 自然语言处理:Natural Language Processing,主要目标是让计算机能够理解、解释和生成人类语言的数据。 1 基本概念 1.1NLP概念 语言:人类沟通的机构化系统,包括声音、书写符号、手势 自然语言:自然进化…

Python多版本管理

关注后回复 python 获取相关资料 ubuntu18.04 # ubuntu18 默认版本 Python 2.7.17 apt install python python-dev python-pip# ubuntu18 默认版本 Python 3.6.9 apt install python3 python3-dev python3-pip# ubuntu18 使用 python3.8 apt install python3.8 python3.8-dev#…

基于python多线程多进程爬虫的maa作业站技能使用分析

基于python多线程多进程爬虫的maa作业站技能使用分析 技能使用分析 多线程(8核) import json import multiprocessing import requests from multiprocessing.dummy import Pooldef maa(st):url "https://prts.maa.plus/copilot/get/"m …

2025.2.8——一、[护网杯 2018]easy_tornado tornado模板注入

题目来源:BUUCTF [护网杯 2018]easy_tornado 目录 一、打开靶机,整理信息 二、解题思路 step 1:分析已知信息 step 2:目标——找到cookie_secret step 3:构造payload 三、小结 一、打开靶机,整理信…

深度学习里面的而优化函数 Adam,SGD,动量法,AdaGrad 等 | PyTorch 深度学习实战

前一篇文章,使用线性回归模型逼近目标模型 | PyTorch 深度学习实战 本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started 本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】 深度学习里面的而优化函数 …

番外篇:寄人篱下的岁月

在小冷的记忆深处,有一段时光如同一幅陈旧却又色彩浓烈的画卷,被岁月的尘埃轻轻覆盖着。那是他小时候,在二姨家度过的一年。这段经历,在他的生命中刻下了深深的印记,成为了他成长路上一段难以忘怀的过往,其…

腾讯混元3D创作引擎:一站式AI 3D创作平台的深度解析

在当今数字化的时代,3D内容创作正逐渐成为各行各业不可或缺的一部分。从游戏开发到影视制作,再到教育和广告行业,3D技术的应用场景日益广泛。然而,传统的3D创作往往需要大量的时间和技能积累,这使得许多非专业人士望而却步。为了打破这一壁垒,腾讯于2025年1月21日正式发布…

Chrome谷歌多开教程:实用方法与工具

不管是电子商务、技术测试、空投等不同专业领域,还是个人的工作和生活账号管理,使用不同的独立账户往往需要借助Chrome谷歌浏览器多开来提高效率。Chrome谷歌多开有哪些方法和工具?可以来参考以下实用内容。 一、Chrome谷歌多开方法与工具 1…

数据库操作与数据管理——Rust 与 SQLite 的集成

第六章:数据库操作与数据管理 第一节:Rust 与 SQLite 的集成 在本节中,我们将深入探讨如何在 Rust 中使用 SQLite 数据库,涵盖从基本的 CRUD 操作到事务处理、数据模型的构建、性能优化以及安全性考虑等方面。SQLite 是一个轻量…

Android原生开发问题汇总

Fragment顶部出现一个白条怎么办?父类布局搞事情。 layer-list被拉伸问题 Android之 ImageView android:src和tools:src的区别是什么? Android运行时权限的总结,以及EasyPermissions框架的使用 Android Studio添加EasyPemissions Android中module怎…

【AI实践】Cursor上手-跑通Hello World和时间管理功能

背景 学习目的:熟悉Cursor使用环境,跑通基本开发链路。 本人背景:安卓开发不熟悉,了解科技软硬件常识 实践 基础操作 1,下载安装安卓Android Studio 创建一个empty project 工程,名称为helloworld 2&am…

深度解析DeepSeek模型系列:从轻量级到超大规模(附DeepSeek硬件配置清单)

在人工智能领域,深度学习模型的选择对于任务的执行效率和精度至关重要。DeepSeek模型系列提供了多种不同参数量的版本,以满足不同场景下的需求。本文将详细解析DeepSeek模型系列的特点、适用场景以及硬件需求。 DeepSeek模型系列概览 DeepSeek模型系列…

LeetCodeHot 100 第一天

哈希组 1、两数之和使用的是HashMap,如果数字数目比较小可以使用数组作为Hash表,HashMap使用的函数市put,get,containsKey。 2、遇到判断字母异位词首先进行排序,本质上就是找字母异位词的共同之处,也就是…

讯飞绘镜(ai生成视频)技术浅析(五):视频生成

讯飞绘镜(AI生成视频)是一种先进的AI视频生成技术,能够将静态的分镜画面转换为动态视频,并使画面中的元素按照一定的逻辑和动作进行动态展示。 一、讯飞绘镜视频生成技术概述 讯飞绘镜的视频生成技术主要包含以下几个核心模块: 1.视频生成模型:包括生成对抗网络(GAN)…

LabVIEW铅酸蓄电池测试系统

本文介绍了基于LabVIEW的通用飞机铅酸蓄电池测试系统的设计与实现。系统通过模块化设计,利用多点传感器采集与高效的数据处理技术,显著提高了蓄电池测试的准确性和效率。 ​ 项目背景 随着通用航空的快速发展,对飞机铅酸蓄电池的测试需求也…

JVM虚拟机以及跨平台原理

相信大家已经了解到Java具有跨平台的特性,即“一次编译,到处运行”,例如在Windows下编写的程序,无需任何修改就可以在Linux下运行,这是C和C很难做到的。 那么,跨平台是怎样实现的呢?这就要谈及…

重新刷题求职2-DAY7

1.454. 四数相加 II 给你四个整数数组 nums1、nums2、nums3 和 nums4 &#xff0c;数组长度都是 n &#xff0c;请你计算有多少个元组 (i, j, k, l) 能满足&#xff1a; 0 < i, j, k, l < nnums1[i] nums2[j] nums3[k] nums4[l] 0 示例 1&#xff1a; 输入&#…