【深度学习】多目标融合算法(五):定制门控网络CGC(Customized Gate Control)

目录

一、引言

二、CGC(Customized Gate Control,定制门控网络)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

2.3.1 业务场景与建模

2.3.2 模型代码实现

2.3.3 模型训练与推理测试

2.3.4 打印模型结构 

三、总结


一、引言

上一篇我们讲了MMoE多任务网络,通过对每一个任务塔建立Gate门控,对专家网络进行加权平均,Gate门控起到了对多个共享专家重要度筛选的作用。在每轮反向传播时,每个任务tower分别更新对应Gate的参数,以及共享专家的参数。模型主要起到了多目标任务平衡的作用。

今天我们重点将CGC(Customized Gate Control)定制门控网络,核心思想是在MMoE基础上,为每一个任务tower定制独享专家,实用任务独享专家与共享专家共同决定任务Tower的输入,相比于MMoE仅用Gate门控表征任务Tower的方法,CGC引入独享专家,对任务表征更加全面,又通过共享专家保证关联性。

二、CGC(Customized Gate Control,定制门控网络)

2.1 技术原理

CGC(Customized Gate Control)全称为定制门控网络,主要由多个任务塔、对应多组独享专家网络,对应多个门控网络以及一组共享专家网络,专家网络组内可以包含多个专家MLP。核心原理:样本input分别输入共享专家MLP、独立专家MLP、独立专家对应门控网络,门控网络输出为经过softmax的权重分布,维度对应共享专家数num_shared_experts和独立专家数num_task_experts的和,通过对独立专家输出和共享专家输出采用Gate门控加权平均后, 输入到对应的任务Tower。每个任务Tower输入自己对应的独享专家、共享专家、门控加权平均的输入。反向传播时,每个任务更新自己独享专家、独享门控以及共享专家的参数。

  • 共享专家网络:样本数据分别输入num_shared_experts个专家网络进行推理,每个共享专家网络实际上是一个多层感知机(MLP),输入维度为x,输出维度为output_experts_dim。
  • 独享专家网络:样本数据分别输入num_task_experts个专家网络进行推理,每个共享专家网络实际上是一个多层感知机(MLP),输入维度为x,输出维度为output_experts_dim。
  • 门控网络:样本数据输出各自任务对应的门控网络,每个门控网络可以是一个多层感知机,也可以是一个双层的交叉,主要是为了输出专家网络的加权平均权重。
  • 任务网络:对于每一个Task,将各自对应num_shared_experts个共享专家和num_task_experts个独立专家,基于对应gate门控网络的softmax加权平均,作为各自Task的输入,所有Task的输入统一维度均为output_experts_dim。

2.2 技术优缺点

相较于MMoE网络,CGC为每一个任务tower定制独享专家,实用任务独享专家与共享专家共同决定任务Tower的输入,相比于MMoE仅用Gate门控表征任务Tower的方法,CGC引入独享专家,对任务表征更加全面,又通过共享专家保证关联性。

优点:

  • 切断任务tower与其他任务独享专家的联系,使得独享专家能够更专注的学习本任务内的知识与信息。比如切断互动塔与点击专家的联系,只和互动专家同时迭代,让互动目标的学习更加纯粹。
  • 独享专家只受对应任务梯度的影响,不受其他任务梯度的影响,而共享专家可以被多个任务梯度同时更新。
  • 本质上,CGC就是在MMoE上新增了独享专家,MMoE仅有共享专家。

缺点: 

  • 相较于PLE、SNR等,没有学习到专家与专家之间的相互关系,层级堆叠不够。
  • 相较于DeepSeekMoE的路由方法,CGC还是过于定制化与单一话,专家组合不足。

2.3 业务代码实践

2.3.1 业务场景与建模

我们还是以小红书推荐场景为例,针对一个视频,用户可以点红心(互动),也可以点击视频进行播放(点击),针对互动和点击两个目标进行多目标建模

我们构建一个100维特征输入,1组共享专家网络(含2个共享专家),2组独享专家网络(各含2个独享专家),2个门控,2个任务塔的CGC网络,用于建模多目标学习问题,模型架构图如下:

​​​​​​​​​​​​​​

如架构图所示,其中有几个注意的点:

  • num_shared_experts+num_task_expertsGate的维度等于共享专家的维度加上任务独享专家的维度。
  • output_experts_dim:共享专家、独享专家网络的输出维度和task网络的输入维度相同,task网络承接的是专家网络各维度的加权平均值,experts网络与task网络是直接对应关系。
  • Softmax:Gate门控网络对共享专家和独享专家的偏好权重采用Softmax归一化,保证专家网络加权平均后值域相同

2.3.2 模型代码实现

基于pytorch,实现上述CGC网络架构,如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDatasetclass CGCModel(nn.Module):def __init__(self, input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts):super(CGCModel, self).__init__()# 初始化函数外使用初始化变量需要赋值,否则默认使用全局变量# 初始化函数内使用初始化变量不需要赋值 self.num_shared_experts = num_shared_expertsself.num_task_experts = num_task_expertsself.output_experts_dim = output_experts_dim# 初始化共享专家self.shared_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_shared_experts)])# 初始化任务1专家self.task1_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化任务2专家self.task2_experts_2 = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, experts_hidden1_dim),nn.ReLU(),nn.Linear(experts_hidden1_dim, experts_hidden2_dim),nn.ReLU(),nn.Linear(experts_hidden2_dim, output_experts_dim),nn.ReLU()) for _ in range(num_task_experts)])# 初始化门控网络任务1self.gating1_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 初始化门控网络任务2self.gating2_network_2 = nn.Sequential(nn.Linear(input_dim, gate_hidden1_dim),nn.ReLU(),nn.Linear(gate_hidden1_dim, gate_hidden2_dim),nn.ReLU(),nn.Linear(gate_hidden2_dim, num_shared_experts+num_task_experts),nn.Softmax(dim=1))# 定义任务1的输出层self.task1_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task1_dim),nn.Sigmoid()) # 定义任务2的输出层self.task2_head = nn.Sequential(nn.Linear(output_experts_dim, task_hidden1_dim),nn.ReLU(),nn.Linear(task_hidden1_dim, task_hidden2_dim),nn.ReLU(),nn.Linear(task_hidden2_dim, output_task2_dim),nn.Sigmoid()) def forward(self, x):gates1 = self.gating1_network_2(x)gates2 = self.gating2_network_2(x)#定义专家网络输出作为任务塔输入batch_size, _ = x.shapetask1_inputs = torch.zeros(batch_size, self.output_experts_dim)task2_inputs = torch.zeros(batch_size, self.output_experts_dim)for i in range(self.num_shared_experts):task1_inputs += self.shared_experts_2[i](x) * gates1[:, i].unsqueeze(1) + self.task1_experts_2[i](x) * gates1[:, i+self.num_shared_experts].unsqueeze(1)task2_inputs += self.shared_experts_2[i](x) * gates2[:, i].unsqueeze(1) + self.task2_experts_2[i](x) * gates2[:, i+self.num_shared_experts].unsqueeze(1)task1_outputs = self.task1_head(task1_inputs)task2_outputs = self.task2_head(task2_inputs)return task1_outputs, task2_outputs# 实例化模型对象
experts_hidden1_dim = 64
experts_hidden2_dim = 32
output_experts_dim = 16
gate_hidden1_dim = 16
gate_hidden2_dim = 8
task_hidden1_dim = 32
task_hidden2_dim = 16
output_task1_dim = 1
output_task2_dim = 1
num_shared_experts = 2
num_task_experts = 2# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 100
num_samples = 1024
X_train = torch.randint(0, 2, (num_samples, input_dim)).float()
y_train_task1 = torch.rand(num_samples, output_task1_dim)  # 假设任务1的输出维度为1
y_train_task2 = torch.rand(num_samples, output_task2_dim)  # 假设任务2的输出维度为1# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)model = CGCModel(input_dim, experts_hidden1_dim, experts_hidden2_dim, output_experts_dim, task_hidden1_dim, task_hidden2_dim, output_task1_dim, output_task2_dim, gate_hidden1_dim, gate_hidden2_dim, num_shared_experts, num_task_experts)# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练循环
num_epochs = 100
for epoch in range(num_epochs):model.train()running_loss = 0.0for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):# 前向传播: 获取预测值#print(batch_idx, X_batch )#print(f'Epoch [{epoch+1}/{num_epochs}-{batch_idx}], Loss: {running_loss/len(train_loader):.4f}')outputs_task1, outputs_task2 = model(X_batch)# 计算每个任务的损失loss_task1 = criterion_task1(outputs_task1, y_task1_batch)loss_task2 = criterion_task2(outputs_task2, y_task2_batch)total_loss = loss_task1 + loss_task2# 反向传播和优化optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()if epoch % 10 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')print(model)
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 模型预测
model.eval()
with torch.no_grad():test_input = torch.randint(0, 2, (1, input_dim)).float()  # 构造一个测试样本pred_task1, pred_task2 = model(test_input)print(f'互动目标预测结果: {pred_task1}')print(f'点击目标预测结果: {pred_task2}')

相比于上一篇MMoE中的代码,CGC复杂了很多,新增了2组独享专家,且在门控与独享、共享专家加权平均计算的时候需要进行处理,很容易出问题。

2.3.3 模型训练与推理测试

运行上述代码,模型启动训练,Loss逐渐收敛,测试结果如下:

2.3.4 打印模型结构 ​​​​​​​

三、总结

本文详细介绍了CGC多任务模型的算法原理、算法优势,他是下一篇PLE多层多任务模型的基础,并以小红书业务场景为例,构建CGC网络结构并使用pytorch代码实现对应的网络结构、训练流程。相比于MMoE,CGC新增独享专家网络,通过gate门控的串联,切断任务Tower与其他任务独享专家的联系,使得独享专家能够更专注的学习本任务内的知识与信息。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

【深度学习】多目标融合算法(三):混合专家网络MOE(Mixture-of-Experts) 

 【深度学习】多目标融合算法(四):多门混合专家网络MMOE(Multi-gate Mixture-of-Experts)​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/73852.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在线pdf处理网站合集

1、PDF24 Tools:https://tools.pdf24.org/zh/ 2、PDF派:https://www.pdfpai.com/ 3、ALL TO ALL:https://www.alltoall.net/ 4、CleverPDF:https://www.cleverpdf.com/cn 5、Doc Small:https://docsmall.com/ 6、Aconv…

网络编程-实现客户端通信

#include <stdio.h> #include <stdlib.h> #include <string.h> #include <unistd.h> #include <sys/socket.h> #include <netinet/in.h> #include <sys/select.h>#define MAX_CLIENTS 2 // 最大客户端连接数 #define BUFFER_SI…

力扣100二刷——图论、回溯

第二次刷题不在idea写代码&#xff0c;而是直接在leetcode网站上写&#xff0c;“逼”自己掌握常用的函数。 标志掌握程度解释办法⭐Fully 完全掌握看到题目就有思路&#xff0c;编程也很流利⭐⭐Basically 基本掌握需要稍作思考&#xff0c;或者看到提示方法后能解答⭐⭐⭐Sl…

【大模型实战篇】多模态推理模型Skywork-R1V

1. 背景介绍 近期昆仑万维开源的Skywork R1V模型&#xff0c;是基于InternViT-6B-448px-V2_5以及deepseek-ai/DeepSeek-R1-Distill-Qwen-32B 通过强化学习得到。当然语言模型也可以切换成QwQ-32B。因此该模型最终的参数量大小为38B。 该模型具备多模态推理能力&#xf…

识别并脱敏上传到deepseek/chatgpt的文本文件中的护照信息

本文将介绍一种简单高效的方法解决用户在上传文件到DeepSeek、ChatGPT&#xff0c;文心一言&#xff0c;AI等大语言模型平台过程中的护照号识别和脱敏问题。 DeepSeek、ChatGPT&#xff0c;Qwen&#xff0c;Claude等AI平台工具快速的被接受和使用&#xff0c;用户每天上传的文…

数据驱动进化:AI Agent如何重构手机交互范式?

如果说AIGC拉开了内容生成的序幕&#xff0c;那么AI Agent则标志着AI从“工具”向“助手”的跨越式进化。它不再是简单的问答机器&#xff0c;而是一个能够感知环境、规划任务并自主执行的智能体&#xff0c;更像是虚拟世界中的“全能员工”。 正如行业所热议的&#xff1a;“大…

【AI News | 20250319】每日AI进展

AI Repos 1、XianyuAutoAgent 实现了 24 小时自动化值守的 AI 智能客服系统&#xff0c;支持多专家协同决策、智能议价和上下文感知对话&#xff0c;让我们店铺管理更轻松。主要功能&#xff1a; 智能对话引擎&#xff0c;支持上下文感知和专家路由阶梯降价策略&#xff0c;自…

nginx中间件部署

中间件部署流程 ~高级权限账户安装必要的插件 -> 普通权限账户安装所需要的服务 -> 高级权限账户开启并设置开机自启所安装的服务 -> iptables放行所需要的服务 普通权限账户安装NGINX中间件 1、拥有高级权限的账户安装必要的插件 sudo yum install -y gcc-c make…

C语言自定义类型【结构体】详解,【结构体内存怎么计算】 详解 【热门考点】:结构体内存对齐

引言 详细讲解什么是结构体&#xff0c;结构体的运用&#xff0c; 详细介绍了结构体在内存中占几个字节的计算。 【热门考点】&#xff1a;结构体内存对齐 介绍了&#xff1a;结构体传参 一、什么是结构体&#xff1f; 结构是⼀些值的集合&#xff0c;这些值称为成员变量。结构…

前端应用更新通知机制全解析:构建智能化版本更新策略

引言&#xff1a;数字时代的更新挑战 在持续交付的现代软件开发模式下&#xff0c;前端应用平均每周产生2-3次版本迭代。但据Google研究报告显示&#xff0c;38%的用户在遇到功能异常时仍在使用过期版本的应用。如何优雅地实现版本更新通知&#xff0c;已成为提升用户体验的关…

Apache DolphinScheduler:一个可视化大数据工作流调度平台

Apache DolphinScheduler&#xff08;海豚调度&#xff09;是一个分布式易扩展的可视化工作流任务调度开源系统&#xff0c;适用于企业级场景&#xff0c;提供了一个可视化操作任务、工作流和全生命周期数据处理过程的解决方案。 Apache DolphinScheduler 旨在解决复杂的大数据…

[蓝桥杯 2023 省 B] 飞机降落

[蓝桥杯 2023 省 B] 飞机降落 题目描述 N N N 架飞机准备降落到某个只有一条跑道的机场。其中第 i i i 架飞机在 T i T_{i} Ti​ 时刻到达机场上空&#xff0c;到达时它的剩余油料还可以继续盘旋 D i D_{i} Di​ 个单位时间&#xff0c;即它最早可以于 T i T_{i} Ti​ 时刻…

使用Trae 生成的React版的贪吃蛇

使用Trae 生成的React版的贪吃蛇 首先你想用这个贪吃蛇&#xff0c;你需要先安装Trae Trae 官方地址 他有两种模式 chat builder 我使用的是builder模式,虽然是Alpha.还是可以用。 接下来就是按着需求傻瓜式的操作生成代码 他生成的代码不完全正确&#xff0c;比如没有引入…

AI大模型:(一)1.大模型的发展与局限

说起AI大模型不得不说下机器学习的发展史&#xff0c;机器学习包括传统机器学习、深度学习&#xff0c;而大模型&#xff08;Large Models&#xff09;属于机器学习中的深度学习&#xff08;Deep Learning&#xff09;领域&#xff0c;具体来说&#xff0c;它们通常基于神经网络…

rust学习笔记17-异常处理

今天聊聊rust中异常错误处理 1. 基础类型&#xff1a;Result 和 Option&#xff0c;之前判断空指针就用到过 Option<T> 用途&#xff1a;表示值可能存在&#xff08;Some(T)&#xff09;或不存在&#xff08;None&#xff09;&#xff0c;适用于无需错误信息的场景。 f…

Python:单继承方法的重写

继承&#xff1a;让类和类之间转变为父子关系&#xff0c;子类默认继承父类的属性和方法 单继承&#xff1a; class Person:def eat(self):print("eat")def sing(self):print("sing") class Girl(Person):pass#占位符&#xff0c;代码里面类下面不写任何东…

记录一下aes加密与解密

该文章只做拓展后续会更新&#xff1b;如有出错请指出 首先需要先引入相关依赖 crypto-js 然后直接开始存储 export function aesEncrypt(message: string, key: string) {return aes.encrypt(message, key).toString(); } 之后是解密方式 function decrypt(content: any, key…

[免费]直接整篇翻译pdf工具-支持多种语言

<闲来没事写篇博客填补中文知识库漏洞> 如题&#xff0c;[免费][本地]工具基于开源仓库&#xff1a; 工具 是python&#xff01;太好了&#xff0c;所以各个平台都可以&#xff0c;我这里基于windows. 1. 先把github代码下载下来&#xff1a; git clone https://githu…

UI设计中的用户反馈机制:提升交互体验的关键

hello宝子们...我们是艾斯视觉擅长ui设计和前端数字孪生、大数据、三维建模、三维动画10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 在数字化产品泛滥的今天&#xff0c;用户与界面的每一次交互都在无形中塑造着他们对产品的认知。一个…

Hessian 矩阵是什么

Hessian 矩阵是什么 目录 Hessian 矩阵是什么Hessian 矩阵的性质及举例说明**1. 对称性****2. 正定性决定极值类型****特征值为 2(正),因此原点 ( 0 , 0 ) (0, 0) (0,0) 是极小值点。****3. 牛顿法中的应用****4. 特征值与曲率方向****5. 机器学习中的实际意义**一、定义与…