手写系列——MoE网络

参考:

MOE原理解释及从零实现一个MOE(专家混合模型)_moe代码-CSDN博客

MoE环游记:1、从几何意义出发 - 科学空间|Scientific Spaces 

深度学习之图像分类(二十八)-- Sparse-MLP(MoE)网络详解_sparse moe-CSDN博客

深度学习之图像分类(二十九)-- Sparse-MLP网络详解_sparse mlp-CSDN博客 

 

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 超参数设置
num_experts = 4      # 专家数量
top_k = 2            # 激活专家数
# input_dim = 3072     # CIFAR-10图像展平后维度(32x32x3)
input_dim = 64 * 8 * 8
hidden_dim = 512     # 专家网络隐藏层维度
num_classes = 10     # 分类类别数# MoE层实现(文献[5][7])
class SparseMoE(nn.Module):def __init__(self):super().__init__()self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim)) for _ in range(num_experts)])self.gate = nn.Sequential(nn.Linear(input_dim, num_experts),nn.Softmax(dim=1))# 负载均衡参数(文献[4][7])self.balance_loss_weight = 0.01self.register_buffer('expert_counts', torch.zeros(num_experts))def forward(self, x):# 门控计算gate_scores = self.gate(x)  # [B, num_experts]# Top-k选择(文献[5])topk_scores, topk_indices = torch.topk(gate_scores, top_k, dim=1)mask = F.one_hot(topk_indices, num_experts).float().sum(dim=1)# 专家输出聚合expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)selected_experts = expert_outputs.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, hidden_dim))  # [B, 2, H]# print(f"专家输出维度: {expert_outputs.shape}")# print(f"选择索引维度: {topk_indices.shape}")# print(f"选择专家维度: {selected_experts.shape}")weighted_outputs = (selected_experts  * topk_scores.unsqueeze(-1)).sum(dim=1)# 更新专家使用统计self.expert_counts += mask.sum(dim=0)return weighted_outputsdef balance_loss(self):# 计算负载均衡损失(文献[4][7])expert_probs = self.expert_counts / self.expert_counts.sum()balance_loss = torch.std(expert_probs) * self.balance_loss_weightself.expert_counts.zero_()  # 重置计数器return balance_loss# 完整模型架构(文献[2][6])
class MoEImageClassifier(nn.Module):def __init__(self):super().__init__()self.feature_extractor = nn.Sequential(nn.Conv2d(3, 32, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32, 64, 3, padding=1),nn.ReLU(),nn.MaxPool2d(2))self.moe_layer = SparseMoE()self.classifier = nn.Linear(hidden_dim, num_classes)def forward(self, x):x = self.feature_extractor(x)x = x.view(x.size(0), -1)  # 展平特征x = self.moe_layer(x)return self.classifier(x)# 数据预处理(文献[2])
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)# 训练流程
model = MoEImageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(10):for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)main_loss = criterion(outputs, labels)balance_loss = model.moe_layer.balance_loss()total_loss = main_loss + balance_losstotal_loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/10], Loss: {total_loss.item():.4f}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/71837.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux的基础指令和环境部署,项目部署实战(下)

目录 上一篇:Linxu的基础指令和环境部署,项目部署实战(上)-CSDN博客 1. 搭建Java部署环境 1.1 apt apt常用命令 列出所有的软件包 更新软件包数据库 安装软件包 移除软件包 1.2 JDK 1.2.1. 更新 1.2.2. 安装openjdk&am…

【蓝桥杯】第十五届省赛大学真题组真题解析

【蓝桥杯】第十五届省赛大学真题组真题解析 一、智能停车系统 1、知识点 (1)flex-wrap 控制子元素的换行方式 属性值有: no-wrap不换行wrap伸缩容器不够则自动往下换行wrap-reverse伸缩容器不够则自动往上换行 (2&#xff0…

flink operator v1.10对接华为云对象存储OBS

1 概述 flink operator及其flink集群,默认不直接支持华为云OBS,需要在这些java程序的插件目录放一个jar包,以及修改flink配置后,才能支持集成华为云OBS。 相关链接参考: https://support.huaweicloud.com/bestpracti…

免费PDF工具

Smallpdf.com - A Free Solution to all your PDF Problems Smallpdf - the platform that makes it super easy to convert and edit all your PDF files. Solving all your PDF problems in one place - and yes, free. https://smallpdf.com/#rappSmallpdf.com-解决您所有PD…

去中心化技术P2P框架

中心化网络与去中心化网络 1. 中心化网络 在传统的中心化网络中,所有客户端都通过一个中心服务器进行通信。这种网络拓扑结构通常是一个星型结构,其中服务器作为中心节点,每个客户端只能与服务器通信。如果客户端之间需要通信,必须…

muduo源码阅读:linux timefd定时器

⭐timerfd timerfd 是Linux一个定时器接口,它基于文件描述符工作,并通过该文件描述符的可读事件进行超时通知。可以方便地与select、poll和epoll等I/O多路复用机制集成,从而在没有处理事件时阻塞程序执行,实现高效的零轮询编程模…

Pinia 3.0 正式发布:全面拥抱 Vue 3 生态,升级指南与实战教程

一、重大版本更新解析 2024年2月11日,Vue 官方推荐的状态管理库 Pinia 迎来 3.0 正式版发布,本次更新标志着其全面转向 Vue 3 技术生态。以下是开发者需要重点关注的升级要点: 1.1 核心变更说明 特性3.0 版本要求兼容性说明Vue 支持Vue 3.…

【图像处理 --- Sobel 边缘检测的详解】

Sobel 边缘检测的详解 目录 Sobel 边缘检测的详解1. 梯度计算2. 梯度大小3. 梯度方向4. 非极大值抑制5. 双阈值处理6. 在 MATLAB 中实现 Sobel 边缘检测7.运行结果展示8.关键参数解释9.实验与验证 Sobel 边缘检测是一种经典的图像处理算法,用于检测图像中的边缘。它…

LeetCode 热题100 15. 三数之和

LeetCode 热题100 | 15. 三数之和 大家好,今天我们来解决一道经典的算法题——三数之和。这道题在 LeetCode 上被标记为中等难度,要求我们从一个整数数组中找到所有不重复的三元组,使得三元组的和为 0。下面我将详细讲解解题思路&#xff0c…

基因组组装中的术语1——from HGP

Initial sequencing and analysis of the human genome | Nature 1,分层鸟枪法测序hierarchical shotgun sequencing

安全开发-环境选择

文章目录 个人心得虚拟机选择ubuntu 22.04python环境选择conda下载使用: 个人心得 在做开发时配置一个专门的环境可以使我们在开发中的效率显著提升,可以避免掉很多环境冲突的报错。尤其是python各种版本冲突,还有做渗透工具不要选择windows…

数字体验驱动用户参与增效路径

内容概要 在数字化转型深化的当下,数字内容体验已成为企业与用户建立深度连接的核心切入点。通过个性化推荐引擎与智能数据分析系统的协同运作,企业能够实时捕捉用户行为轨迹,构建精准的用户行为深度洞察模型。这一模型不仅支撑内容分发的动…

Python 字符串(str)全方位剖析:从基础入门、方法详解到跨语言对比与知识拓展

Python 字符串(str)全方位剖析:从基础入门、方法详解到跨语言对比与知识拓展 本文将深入探讨 Python 中字符串(str)的相关知识,涵盖字符串的定义、创建、基本操作、格式化等内容。同时,会将 Py…

使用C++实现简单的TCP服务器和客户端

使用C实现简单的TCP服务器和客户端 介绍准备工作1. TCP服务器实现代码结构解释 2. TCP客户端实现代码结构解释 3. 测试1.编译:2.运行 结语 介绍 本文将通过一个简单的例子,介绍如何使用C实现一个基本的TCP服务器和客户端。这个例子展示了如何创建服务器…

Java Web开发实战与项目——Spring Boot与Spring Cloud微服务项目实战

企业级应用中,微服务架构已经成为一种常见的开发模式。Spring Boot与Spring Cloud提供了丰富的工具和组件,帮助开发者快速构建、管理和扩展微服务应用。本文将通过一个实际的微服务项目,展示如何使用Spring Boot与Spring Cloud构建微服务架构…

VMware建立linux虚拟机

本文适用于初学者,帮助初学者学习如何创建虚拟机,了解在创建过程中各个选项的含义。 环境如下: CentOS版本: CentOS 7.9(2009) 软件: VMware Workstation 17 Pro 17.5.0 build-22583795 1.配…

Linux8-互斥锁、信号量

一、前情回顾 void perror(const char *s);功能:参数: 二、资源竞争 1.多线程访问临界资源时存在资源竞争(存在资源竞争、造成数据错乱) 临界资源:多个线程可以同时操作的资源空间(全局变量、共享内存&a…

LD_PRELOAD 绕过 disable_function 学习

借助这位师傅的文章来学习通过LD_PRELOAD来绕过disable_function的原理 【PHP绕过】LD_PRELOAD bypass disable_functions_phpid绕过-CSDN博客 感谢这位师傅的贡献 介绍 静态链接: (1)举个情景来帮助理解: 假设你要搬家&#x…

【无人集群系列---无人机集群编队算法】

【无人集群系列---无人机集群编队算法】 一、核心目标二、主流编队控制方法1. 领航-跟随法(Leader-Follower)2. 虚拟结构法(Virtual Structure)3. 行为法(Behavior-Based)4. 人工势场法(Artific…

Oracle Fusion Middleware更改weblogic密码

前言 当用户忘记weblogic密码时,且无法登录到web界面中,需要使用服务器命令更改密码 更改方式 1、备份 首先进入 weblogic 安装目录,备份三个文件:boot.properties,DefaultAuthenticatorInit.ldift,Def…