【2024 CSDN博客之星】技术洞察类:从DeepSeek-V3的成功,看MoE混合专家网络对深度学习算法领域的影响(MoE代码级实战)

目录

一、引言

1.1 本篇文章侧重点

1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)

二、MoE(Mixture-of-Experts,混合专家网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

经历了大模型2024一整年度的兵荒马乱,从年初的Sora文生视频到MiniMax顿悟后的开源,要说年度最大赢家,当属deepseek莫属:年中,deepseek-v2以其1/100的售价,横扫包括gpt4、qwen、百度等一系列商用模型;年底,deepseek-v3发布,以MoE为核心的专家网络技术,让其以极低的推理成本,获得了媲美gpt-4o的效果。

1.1 本篇文章侧重点

本篇文章作为年度技术洞察类文章,今天的重点不是deepseek的训练与推理,如果对训练和推理感兴趣,我在年中写过一篇训练与推理的实战,其中详细讲述了DeepSeek-V2大模型的训练和推理,详细可点击:AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战(只需将V2替换为V3,即可体验最新版本deepseek)。今天的重点是更深一个层次,带大家代码级认识MoE混合专家网络技术。

1.2 技术洞察—MoE(Mixture-of-Experts,混合专家网络)

MoE(Mixture-of-Experts) 并不是一个新词,近7-8年间,在我做推荐系统精排模型过程中,业界将MoE技术应用于推荐系统多任务学习,以MMoE(2018,google)、PLE(2020,腾讯)为基石,通过门控网络为多个专家网络加权平均,定义每个专家的重要性,解决多目标、多场景、多任务等问题。近1-2年间,基于MoE思想构建的大模型层出不穷,通过路由网络对多个专家网络进行选择,提升推理效率,经典模型有DeepSeekMoE、Mixtral 8x7B、Flan-MoE等。 

万丈高楼平地起,今天我们不聊空中楼阁,而是带大家实现一个MoE网络,了解MoE代码是怎么构建的,大家可以以此代码为基础,继续垒砖,根据自己的业务场景,创新性的构建自己的专家网络。 

二、MoE(Mixture-of-Experts,混合专家网络)

2.1 技术原理

MoE(Mixture-of-Experts)全称为混合专家网络,主要由多个专家网络、多个任务塔、门控网络构成。核心原理:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim;同时,样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1);将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,所以Task的输入统一维度均为output_experts_dim。在每次反向传播迭代时,对Gate和num_experts个专家参数进行更新,Gate和专家网络的参数受任务Task A、B共同影响。

  • 专家网络:样本数据分别输入num_experts个专家网络进行推理,每个专家网络实际上是一个前馈神经网络(MLP),输入维度为x,输出维度为output_experts_dim。
  • ​​​​​​​门控网络:样本数据输入门控网络,门控网络也是一个MLP(可以为多层,也可以为一层),输出为num_experts个experts专家的概率分布,维度为num_experts(菜用softmax将输出归一化,各个维度加起来和为1)。
  • 任务网络:将每个专家网络的输出,基于gate门控网络的softmax加权平均,作为Task的输入,Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于传统的DNN网络,MoE的本质是通过多个专家网络对预估任务共同决策,引入Gate作为专家的裁判,给每一个专家打分,判定哪个专家更加权威。(DeepSeekMoE的Router与Gate类似,区别是Gate为每一个专家赋分,加权平均,Router对专家进行选择,推理速度更快)。相较于传统的DNN网络:

优点:

  • 多个DNN专家网络投票共同决定推理结果,相较于单个DNN网络泛化性更好,准确率更高。
  • Gate网络基于多个Task任务进行反馈收敛,可以学到多个Task任务数据的平衡性。

缺点: 

  • 朴素的MoE仅使用了一个Gate网络,虽然Gate网络由多个Task任务共同收敛学习得到,具有一定的平衡性,但对于每个Task的个性化能力仍然不足。(Google针对此缺点发布了MMoE)
  • 底层多个专家网络均为共享专家,输入均为样本数据,参数的差异主要由初始化的不同得到,并不具备特异性。(腾讯针对此缺点发布了PLE)
  • 输入Input均为全部样本数据,学不出不同场景任务的差异性,需要在输入层对场景特征进行拆分(阿里针对此缺点发布了ESMM)

2.3 业务代码实践

2.3.1 业务场景与建模

我们仍然以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

我们构建一个100维特征输入,4个experts专家网络,2个task任务的,1个门控的MoE网络,用于建模跨场景多任务学习问题,模型架构图如下:

​​​​​​​

如架构图所示,其中有几个注意的点:

  1. num_experts:门控gate的输出维度和专家数相同,均为num_experts,因为gate的用途是对专家网络最后一层进行加权平均,gate维度与专家数是直接对应关系。
  2. output_experts_dim:专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  3. Softmax:Gate门控网络对最后一层采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass MoEModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts):super(MoEModel, self).__init__()self.num_experts = num_expertsself.output_experts_dim = output_experts_dim# 初始化多个专家网络self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_experts)])# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) # 初始化门控网络self.gating_network = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_experts),nn.Softmax(dim=1))def forward(self, x):# 计算输入数据通过门控网络后的权重gates = self.gating_network(x)#print(gates)batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)# 计算每个专家的输出并加权求和for i in range(self.num_experts):expert_output = self.experts[i](x)task1_inputs += expert_output * gates[:, i].unsqueeze(1)task2_inputs += expert_output * gates[:, i].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
num_experts = 4  # 假设有4个专家
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 3
output_task2_dim = 2# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为5
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为3# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = MoEModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'一级场景预测结果: {pred_task1}')print(f'二级场景预测结果: {pred_task2}')

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 

使用print(model)打印模型结构如下

三、总结

本文代码级脚踏实地讲解了DeepSeek大模型、MMoE推荐模型中的MoE(Mixture-of-Experts)技术,该技术的主要思想是通过门控(gate)或路由(router)网络,对多个专家进行加权平均或筛选,将一个DNN网络裂变为多个DNN网络后,投票决定预测结果,相较于单一的DNN网络,具有更强的容错性、泛化性与准确性,同时可以提高推理速度,节省推理资源。

技术洞察结论:MoE技术未来将成为大模型和推荐系统进一步突破的关键技术,个人认为该技术为2024年算法基础技术中的SOTA,但其实并没有那么神秘,通过本篇文章,可以试着动手实现一个MoE,再基于自己的业务场景,对齐专家网络、门控网络、任务网络进行创新,期待本篇文章对您有帮助!

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/69640.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

防火墙是什么?详解网络安全的关键守护者

当今信息化时代,企业和个人在享受数字生活带来的便利时,也不可避免地面对各种潜在的风险。防火墙作为网络安全体系中的核心组件,就像一道牢不可破的防线,保护着我们的数据和隐私不受外界威胁的侵害。那么防火墙是什么?…

Windows系统下设置Vivado默认版本:让工程文件按需打开

在FPGA开发过程中,我们常常需要在一台电脑上安装多个不同版本的Vivado软件,以满足不同项目的需求。然而,当双击打开一个Vivado工程文件(.xpr)时,系统默认会调用一个固定的版本,这可能并不是我们…

DeepSeek模型架构及优化内容

DeepSeek v1版本 模型结构 DeepSeek LLM基本上遵循LLaMA的设计: 采⽤Pre-Norm结构,并使⽤RMSNorm函数. 利⽤SwiGLU作为Feed-Forward Network(FFN)的激活函数,中间层维度为8/3. 去除绝对位置编码,采⽤了…

蓝桥杯---N字形变换(leetcode第6题)题解

文章目录 1.问题重述2.例子分析3.思路讲解4.代码分析 1.问题重述 这个题目可以是Z字形变换,也可以叫做N字形变换: 给定我们一串字符,我们需要把这串字符按照先往下写,再往右上方去写,再往下去写,再往右上…

vscode无法ssh连接远程机器解决方案

远程服务器配置问题 原因:远程服务器的 SSH 服务配置可能禁止了 TCP 端口转发功能,或者 VS Code Server 在远程服务器上崩溃。 解决办法 检查 SSH 服务配置:登录到远程服务器,打开 /etc/ssh/sshd_config 文件,确保以下…

LogicFlow自定义节点:矩形、HTML(vue3)

效果: LogicFlow 内部是基于MVVM模式进行开发的,分别使用preact和mobx来处理 view 和 model,所以当我们自定义节点的时候,需要为这个节点定义view和model。 参考官方文档:节点 | LogicFlow 1、自定义矩形节点 custo…

19.3 连接数据库

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 ​​​​​​​需要北风数据库的请留言自己的信箱。 连接数据库使用OleDbConnection(数据连接)类&#xff…

19.2 C#数据库操作概览

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 需要北风数据库的请留言自己的信箱。 C#对数据的处理主要集中在System.Data命名空间。 对数据操作会使用到以下几个类&#xff1a…

YOLOv11实时目标检测 | 摄像头视频图片文件检测

在上篇文章中YOLO11环境部署 || 从检测到训练https://blog.csdn.net/2301_79442295/article/details/145414103#comments_36164492,我们详细探讨了YOLO11的部署以及推理训练,但是评论区的观众老爷就说了:“博主博主,你这个只能推理…

JavaEE架构

一.架构选型 1.VM架构 VM架构通常指的是虚拟机(Virtual Machine)的架构。虚拟机是一种软件实现的计算机系统,它模拟了物理计算机的功能,允许在单一物理硬件上运行多个操作系统实例。虚拟机架构主要包括以下几个关键组件&#xff…

[笔记] 汇编杂记(持续更新)

文章目录 前言举例解释函数的序言函数的调用栈数据的传递 总结 前言 举例解释 // Type your code here, or load an example. int square(int num) {return num * num; }int sub(int num1, int num2) {return num1 - num2; }int add(int num1, int num2) {return num1 num2;…

如何在Linux中设置定时任务(cron)

在Linux系统中,定时任务是自动执行任务的一种非常方便的方式,常常用于定期备份数据、更新系统或清理日志文件等操作。cron是Linux下最常用的定时任务管理工具,它允许用户根据设定的时间间隔自动执行脚本和命令。在本文中,我们将详…

【MySQL】我在广州学Mysql 系列—— 数据备份与还原

ℹ️大家好,我是练小杰,今天周一,过两天就是元宵节了,今年元宵节各位又要怎么过呢!! 本文主要对Mysql数据库中的数据备份与还原内容进行讨论!! 回顾:👉【MySQ…

【redis】数据类型之hash

Redis中的Hash数据类型是一种用于存储键值对集合的数据结构。与Redis的String类型不同,Hash类型允许你将多个字段(field)和值(value)存储在一个单独的key下,从而避免了将多个相关数据存储为多个独立的key。…

element-plus 解决el-dialog背后的页面滚动问题,及其内容有下拉框出现错位问题

这个问题通常是因为 el‑dialog 默认会锁定 body 的滚动&#xff08;通过给 body 添加隐藏滚动条的样式&#xff09;&#xff0c;从而导致页面在打开对话框时跳转到顶部。解决方法是在使用 el‑dialog 时禁用锁定滚动功能。 <el-dialogv-model"dialogVisible":lo…

SpringBoot+Dubbo+zookeeper 急速入门案例

项目目录结构&#xff1a; 第一步&#xff1a;创建一个SpringBoot项目&#xff0c;这里选择Maven项目或者Spring Initializer都可以&#xff0c;这里创建了一个Maven项目&#xff08;SpringBoot-Dubbo&#xff09;&#xff0c;pom.xml文件如下&#xff1a; <?xml versio…

游戏引擎学习第96天

讨论了优化和速度问题&#xff0c;以便简化调试过程 节目以一个有趣的类比开始&#xff0c;提到就像某些高端餐厅那样&#xff0c;菜单上充满了听起来陌生或不太清楚的描述&#xff0c;需要依靠服务员进一步解释。虽然这听起来有些奇怪&#xff0c;但实际上&#xff0c;它反映…

Spring Boot + MyBatis Field ‘xxx‘ doesn‘t have a default value 问题排查与解决

目录 1. 问题所示2. 原理分析3. 解决方法1. 问题所示 执行代码的时候,出现某个字段无法添加 ### Error updating database. Cause: java.sql.SQLException: Field e_f_id doesnt have a default value ### The error may exist in cn

【分布式理论9】分布式协同:分布式系统进程互斥与互斥算法

文章目录 一、互斥问题及分布式系统的特性二、分布式互斥算法1. 集中互斥算法调用流程优缺点 2. 基于许可的互斥算法&#xff08;Lamport 算法&#xff09;调用流程优缺点 3. 令牌环互斥算法调用流程优缺点 三、三种算法对比 在分布式系统中&#xff0c;多个应用服务可能会同时…

redo和binlog区别

事务是数据库区别于文件系统的最重要功能&#xff0c;数据库事务支持ACID四个特性&#xff0c;其中I&#xff1a;隔离性是通过锁的方式实现的&#xff0c;剩下的A&#xff1a;原子性 C&#xff1a;一致性 D&#xff1a;持久性是通过redo日志、undo日志、binlog日志来实现的。 我…