20250226-代码笔记05-class CVRP_Decoder

文章目录

  • 前言
  • 一、class CVRP_Decoder(nn.Module):__init__(self, **model_params)
    • 函数功能
    • 函数代码
  • 二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)
    • 函数功能
    • 函数代码
  • 三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)
    • 函数功能
    • 函数代码
  • 四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)
    • 函数功能
    • 函数代码
  • 五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)
    • 函数功能
    • 函数代码
  • 附录
    • class CVRP_Decoder代码(全)


前言

class CVRP_DecoderCVRP_Model.py里的类。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPModel.py


一、class CVRP_Decoder(nn.Module):init(self, **model_params)

函数功能

init 方法是 CVRP_Decoder 类中的构造函数,主要功能是初始化该类所需的所有网络层、权重矩阵和参数。
该方法设置了用于多头注意力机制的权重、一个用于表示"遗憾"的参数、以及其他必要的操作用于计算注意力权重。

执行流程图链接
在这里插入图片描述

函数代码

    def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']# self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))self.regret_embedding.data.uniform_(-1, 1)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.k = None  # saved key, for multi-head attentionself.v = None  # saved value, for multi-head_attentionself.single_head_key = None  # saved, for single-head attention# self.q1 = None  # saved q1, for multi-head attentionself.q2 = None  # saved q2, for multi-head attention

二、class CVRP_Decoder(nn.Module):set_kv(self, encoded_nodes)

函数功能

set_kv 方法的功能是将 encoded_nodes 中的节点嵌入转换为多头注意力机制所需的 键(K)值(V),并将它们分别保存为类的属性。
这个方法将输入的节点嵌入通过权重矩阵进行线性变换,得到键和值的表示,并为后续的多头注意力计算做好准备。
执行流程图链接
在这里插入图片描述

函数代码

    def set_kv(self, encoded_nodes):# encoded_nodes.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)# shape: (batch, head_num, problem+1, qkv_dim)self.single_head_key = encoded_nodes.transpose(1, 2)# shape: (batch, embedding, problem+1)

三、class CVRP_Decoder(nn.Module):set_q1(self, encoded_q1)

函数功能

set_q1 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。
该方法接受输入的查询张量 encoded_q1,通过线性层 self.Wq_1 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力机制计算的形状。计算出的查询会被保存为类的属性 q1,供后续使用。

执行流程图链接
在这里插入图片描述

函数代码

    def set_q1(self, encoded_q1):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)

四、class CVRP_Decoder(nn.Module):set_q2(self, encoded_q2)

函数功能

set_q2 方法的主要功能是 计算查询(Q) 并将其转换为适用于多头注意力机制的形状。
该方法接收输入的查询张量 encoded_q2,通过线性层 self.Wq_2 映射到一个新的维度,并使用 reshape_by_heads 函数将其调整为适合多头注意力计算的形状。
执行流程图链接
在这里插入图片描述

函数代码

    def set_q2(self, encoded_q2):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)

五、class CVRP_Decoder(nn.Module):forward(self, encoded_last_node, load, ninf_mask)

函数功能

forward 方法是 CVRP_Decoder 类中的前向传播函数,主要功能是执行 多头自注意力机制 和 单头注意力计算,并最终输出每个可能节点的选择概率(probs)。
该方法通过多头注意力计算、前馈神经网络处理,以及概率计算来进行节点选择。

执行流程图链接
在这里插入图片描述

函数代码

    def forward(self, encoded_last_node, load, ninf_mask):# encoded_last_node.shape: (batch, pomo, embedding)# load.shape: (batch, pomo)# ninf_mask.shape: (batch, pomo, problem)head_num = self.model_params['head_num']#  Multi-Head Attention#######################################################input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)# shape = (batch, group, EMBEDDING_DIM+1)q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)# shape: (batch, head_num, pomo, qkv_dim)# q = self.q1 + self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)# q = q_last# shape: (batch, head_num, pomo, qkv_dim)q = self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)# shape: (batch, pomo, head_num*qkv_dim)mh_atten_out = self.multi_head_combine(out_concat)# shape: (batch, pomo, embedding)#  Single-Head Attention, for probability calculation#######################################################score = torch.matmul(mh_atten_out, self.single_head_key)# shape: (batch, pomo, problem)sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']logit_clipping = self.model_params['logit_clipping']score_scaled = score / sqrt_embedding_dim# shape: (batch, pomo, problem)score_clipped = logit_clipping * torch.tanh(score_scaled)score_masked = score_clipped + ninf_maskprobs = F.softmax(score_masked, dim=2)# shape: (batch, pomo, problem)return probs

附录

class CVRP_Decoder代码(全)

class CVRP_Decoder(nn.Module):def __init__(self, **model_params):super().__init__()self.model_params = model_paramsembedding_dim = self.model_params['embedding_dim']head_num = self.model_params['head_num']qkv_dim = self.model_params['qkv_dim']# self.Wq_1 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_2 = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wq_last = nn.Linear(embedding_dim+1, head_num * qkv_dim, bias=False)self.Wk = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.Wv = nn.Linear(embedding_dim, head_num * qkv_dim, bias=False)self.regret_embedding = nn.Parameter(torch.Tensor(embedding_dim))self.regret_embedding.data.uniform_(-1, 1)self.multi_head_combine = nn.Linear(head_num * qkv_dim, embedding_dim)self.k = None  # saved key, for multi-head attentionself.v = None  # saved value, for multi-head_attentionself.single_head_key = None  # saved, for single-head attention# self.q1 = None  # saved q1, for multi-head attentionself.q2 = None  # saved q2, for multi-head attentiondef set_kv(self, encoded_nodes):# encoded_nodes.shape: (batch, problem+1, embedding)head_num = self.model_params['head_num']self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)# shape: (batch, head_num, problem+1, qkv_dim)self.single_head_key = encoded_nodes.transpose(1, 2)# shape: (batch, embedding, problem+1)def set_q1(self, encoded_q1):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q1 = reshape_by_heads(self.Wq_1(encoded_q1), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def set_q2(self, encoded_q2):# encoded_q.shape: (batch, n, embedding)  # n can be 1 or pomohead_num = self.model_params['head_num']self.q2 = reshape_by_heads(self.Wq_2(encoded_q2), head_num=head_num)# shape: (batch, head_num, n, qkv_dim)def forward(self, encoded_last_node, load, ninf_mask):# encoded_last_node.shape: (batch, pomo, embedding)# load.shape: (batch, pomo)# ninf_mask.shape: (batch, pomo, problem)head_num = self.model_params['head_num']#  Multi-Head Attention#######################################################input_cat = torch.cat((encoded_last_node, load[:, :, None]), dim=2)# shape = (batch, group, EMBEDDING_DIM+1)q_last = reshape_by_heads(self.Wq_last(input_cat), head_num=head_num)# shape: (batch, head_num, pomo, qkv_dim)# q = self.q1 + self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)# q = q_last# shape: (batch, head_num, pomo, qkv_dim)q = self.q2 + q_last# # shape: (batch, head_num, pomo, qkv_dim)out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)# shape: (batch, pomo, head_num*qkv_dim)mh_atten_out = self.multi_head_combine(out_concat)# shape: (batch, pomo, embedding)#  Single-Head Attention, for probability calculation#######################################################score = torch.matmul(mh_atten_out, self.single_head_key)# shape: (batch, pomo, problem)sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']logit_clipping = self.model_params['logit_clipping']score_scaled = score / sqrt_embedding_dim# shape: (batch, pomo, problem)score_clipped = logit_clipping * torch.tanh(score_scaled)score_masked = score_clipped + ninf_maskprobs = F.softmax(score_masked, dim=2)# shape: (batch, pomo, problem)return probs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/70722.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

洛谷 P3628/SPOJ 15648 APIO2010 特别行动队 Commando

题意 你有一支由 n n n 名预备役士兵组成的部队,士兵从 1 1 1 到 n n n 编号,你要将他们拆分成若干特别行动队调入战场。出于默契的考虑,同一支特别行动队中队员的编号应该连续,即为形如 i , i 1 , ⋯ , i k i, i 1, \cdo…

PCL源码分析:曲面法向量采样

文章目录 一、简介二、源码分析三、实现效果参考资料一、简介 曲面法向量点云采样,整个过程如下所述: 1、空间划分:使用递归方法将点云划分为更小的区域, 每次划分选择一个维度(X、Y 或 Z),将点云分为两部分,直到划分区域内的点少于我们指定的数量,开始进行区域随机采…

Go语言--语法基础2--下载安装

2、下载安装 1、下载源码包: go1.18.4.linux-amd64.tar.gz。 官方地址:https://golang.google.cn/dl/ 云盘地址:链接: https://pan.baidu.com/s/1N2jrRHaPibvmmNFep3VYag 提 取码: zkc3 2、将下载的源码包解压…

lowagie(itext)老版本手绘PDF,包含页码、水印、图片、复选框、复杂行列合并等。

入口类:exportPdf ​ package xcsy.qms.webapi.service;import com.alibaba.fastjson.JSONArray; import com.alibaba.fastjson.JSONObject; import com.alibaba.nacos.common.utils.StringUtils; import com.ibm.icu.text.RuleBasedNumberFormat; import com.lowa…

3-2 WPS JS宏 工作簿的打开与保存(模板批量另存为工作)学习笔记

************************************************************************************************************** 点击进入 -我要自学网-国内领先的专业视频教程学习网站 *******************************************************************************************…

Ubuntu20.04之VNC的安装使用与常见问题

Ubuntu20.04之VNC的安装与使用 安装图形桌面选择安装gnome桌面选择安装xface桌面 VNC-Server安装配置开机自启 VNC Clientroot用户无法登入问题临时方案永久方案 安装图形桌面 Ubuntu20.04主流的图形桌面有gnome和xface两种,两种桌面的安装方式我都会写&#xff0c…

Day46 反转字符串

I. 编写一个函数,其作用是将输入的字符串反转过来。输入字符串以字符数组 s 的形式给出。 不要给另外的数组分配额外的空间,你必须原地修改输入数组、使用 O(1) 的额外空间解决这一问题。 class Solution {public void reverseString(char[] s) {int i …

用FileZilla Server 1.9.4给Windows Server 2025搭建FTP服务端

FileZilla Server 是一款免费的开源 FTP 和 FTPS 服务器软件,分为服务器版和客户端版。服务器版原本只支持Windows操作系统,比如笔者曾长期使用过0.9.60版,那时候就只支持Windows操作系统。当时我们生产环境对FTP稳定性要求较高,比…

【愚公系列】《Python网络爬虫从入门到精通》033-DataFrame的数据排序

标题详情作者简介愚公搬代码头衔华为云特约编辑,华为云云享专家,华为开发者专家,华为产品云测专家,CSDN博客专家,CSDN商业化专家,阿里云专家博主,阿里云签约作者,腾讯云优秀博主,腾讯云内容共创官,掘金优秀博主,亚马逊技领云博主,51CTO博客专家等。近期荣誉2022年度…

营销过程乌龟图模版

营销过程乌龟图模版 输入 公司现状产品服务客户问询客户期望电话、电脑系统品牌软件硬件材料 售前 - 沟通 - 确定需求 - 满足需求 - 售后 机料环 电话、电脑等设备软件硬件、系统品牌等工具材料 人 责任人协助者生产者客户 法 订单由谁评审控制程序营销过程控制程序顾客满意度…

Kubernetes (K8S) 高效使用技巧与实践指南

Kubernetes(K8S)作为容器编排领域的核心工具,其灵活性和复杂性并存。本文结合实战经验,从运维效率提升、生产环境避坑、核心功能应用等维度,总结高频使用技巧与最佳实践,分享如何快速掌握 K8S。 一、kubect…

Idea java项目结构介绍

一般来说,一个典型的 IntelliJ IDEA Java 项目具有特定的结构,以下是对其主要部分的介绍: 项目根目录 项目的最顶层目录,包含了整个项目的所有文件和文件夹,通常以项目名称命名。在这个目录下可以找到.idea文件夹、.g…

C++大整数类的设计与实现

1. 简介 我们知道现代的计算机大多数都是64位的,因此能处理最大整数为 2 64 − 1 2^{64}-1 264−1。那如果是超过了这个数怎么办呢,那就需要我们自己手动模拟数的加减乘除了。 2. 思路 我们可以用一个数组来存储大数,数组中的每一个位置表…

2024年第十五届蓝桥杯大赛软件赛省赛Python大学A组真题解析

文章目录 试题A: 拼正方形(本题总分:5 分)解析答案试题B: 召唤数学精灵(本题总分:5 分)解析答案试题C: 数字诗意解析答案试题A: 拼正方形(本题总分:5 分) 【问题描述】 小蓝正在玩拼图游戏,他有7385137888721 个2 2 的方块和10470245 个1 1 的方块,他需要从中挑出一些…

开源RAG主流框架有哪些?如何选型?

开源RAG主流框架有哪些?如何选型? 一、开源RAG框架全景图 (一)核心框架类型对比 类型典型工具技术特征适用场景传统RAGLangChain, Haystack线性流程(检索→生成)通用问答、知识库检索增强型RAGRAGFlow, AutoRAG支持重排序、多路召回优化高精度问答、复杂文档处理轻量级…

Java SE与Java EE

Java SE(Java 平台标准版) Java SE 是 Java 平台的核心,提供了 Java 语言的基础功能。它包含了 Java 开发工具包(JDK),其中有 Java 编译器(javac)、Java 虚拟机(JVM&…

【Java企业生态系统的演进】从单体J2EE到云原生微服务

Java企业生态系统的演进:从单体J2EE到云原生微服务 目录标题 Java企业生态系统的演进:从单体J2EE到云原生微服务摘要1. 引言2. 整体框架演进:从原始Java到Spring Cloud2.1 原始Java阶段(1995-1999)2.2 J2EE阶段&#x…

kicad中R树的使用

在 KiCad 中,使用 R树(R-tree)进行空间索引和加速查询通常不在用户层面直接操作,而是作为工具的一部分用于优化电路板设计的性能,尤其在布局、碰撞检测、设计规则检查(DRC)以及元件搜索等方面。…

org.springframework.boot不存在的其中一个解决办法

最近做项目的时候发现问题,改了几次pom.xml文件之后突然发现项目中的注解全部爆红。 可以尝试点击左上角的循环小图标,同步所有maven项目。 建议顺便检查一下Project Structure中的SDK和Language Level是否对应,否则可能报类似:“…

C语言实现通讯录项目

一、通讯录功能 实现一个可以存放100个人的信息的通讯录(这里采用静态版本),每个人的信息有姓名、性别、年龄、电话、地址等。 通讯录可以执行的操作有添加联系人信息、删除指定联系人、查找指定联系人信息、修改指定联系人信息、显示联系人信…