详细讲一下PYG 里面的torch_geometric.nn.conv.transformer_conv函数

1.首先先讲一下代码

这是官方给的代码:torch_geometric.nn.conv.transformer_conv — pytorch_geometric documentation

import math
import typing
from typing import Optional, Tuple, Unionimport torch
import torch.nn.functional as F
from torch import Tensorfrom torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import (Adj,NoneType,OptTensor,PairTensor,SparseTensor,
)
from torch_geometric.utils import softmaxif typing.TYPE_CHECKING:from typing import overload
else:from torch.jit import _overload_method as overload[docs]class TransformerConv(MessagePassing):r"""The graph transformer operator from the `"Masked Label Prediction:Unified Message Passing Model for Semi-Supervised Classification"<https://arxiv.org/abs/2009.03509>`_ paper... math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},where the attention coefficients :math:`\alpha_{i,j}` are computed viamulti-head dot product attention:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)}{\sqrt{d}} \right)Args:in_channels (int or tuple): Size of each input sample, or :obj:`-1` toderive the size from the first input(s) to the forward method.A tuple corresponds to the sizes of source and targetdimensionalities.out_channels (int): Size of each output sample.heads (int, optional): Number of multi-head-attentions.(default: :obj:`1`)concat (bool, optional): If set to :obj:`False`, the multi-headattentions are averaged instead of concatenated.(default: :obj:`True`)beta (bool, optional): If set, will combine aggregation andskip information via.. math::\mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +(1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}\alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i}with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}[ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1\mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`)dropout (float, optional): Dropout probability of the normalizedattention coefficients which exposes each node to a stochasticallysampled neighborhood during training. (default: :obj:`0`)edge_dim (int, optional): Edge feature dimensionality (in casethere are any). Edge features are added to the keys afterlinear transformation, that is, prior to computing theattention dot product. They are also added to final valuesafter the same linear transformation. The model is:.. math::\mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +\sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(\mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}\right),where the attention coefficients :math:`\alpha_{i,j}` are nowcomputed via:.. math::\alpha_{i,j} = \textrm{softmax} \left(\frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}(\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}{\sqrt{d}} \right)(default :obj:`None`)bias (bool, optional): If set to :obj:`False`, the layer will not learnan additive bias. (default: :obj:`True`)root_weight (bool, optional): If set to :obj:`False`, the layer willnot add the transformed root node features to the output and theoption  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)**kwargs (optional): Additional arguments of:class:`torch_geometric.nn.conv.MessagePassing`."""_alpha: OptTensordef __init__(self,in_channels: Union[int, Tuple[int, int]],out_channels: int,heads: int = 1,concat: bool = True,beta: bool = False,dropout: float = 0.,edge_dim: Optional[int] = None,bias: bool = True,root_weight: bool = True,**kwargs,):kwargs.setdefault('aggr', 'add')super().__init__(node_dim=0, **kwargs)self.in_channels = in_channelsself.out_channels = out_channelsself.heads = headsself.beta = beta and root_weightself.root_weight = root_weightself.concat = concatself.dropout = dropoutself.edge_dim = edge_dimself._alpha = Noneif isinstance(in_channels, int):in_channels = (in_channels, in_channels)self.lin_key = Linear(in_channels[0], heads * out_channels)self.lin_query = Linear(in_channels[1], heads * out_channels)self.lin_value = Linear(in_channels[0], heads * out_channels)if edge_dim is not None:self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)else:self.lin_edge = self.register_parameter('lin_edge', None)if concat:self.lin_skip = Linear(in_channels[1], heads * out_channels,bias=bias)if self.beta:self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)else:self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)if self.beta:self.lin_beta = Linear(3 * out_channels, 1, bias=False)else:self.lin_beta = self.register_parameter('lin_beta', None)self.reset_parameters()[docs]    def reset_parameters(self):super().reset_parameters()self.lin_key.reset_parameters()self.lin_query.reset_parameters()self.lin_value.reset_parameters()if self.edge_dim:self.lin_edge.reset_parameters()self.lin_skip.reset_parameters()if self.beta:self.lin_beta.reset_parameters()@overloaddef forward(self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: NoneType = None,) -> Tensor:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Tensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:pass@overloaddef forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: SparseTensor,edge_attr: OptTensor = None,return_attention_weights: bool = None,) -> Tuple[Tensor, SparseTensor]:pass[docs]    def forward(  # noqa: F811self,x: Union[Tensor, PairTensor],edge_index: Adj,edge_attr: OptTensor = None,return_attention_weights: Optional[bool] = None,) -> Union[Tensor,Tuple[Tensor, Tuple[Tensor, Tensor]],Tuple[Tensor, SparseTensor],]:r"""Runs the forward pass of the module.Args:x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input nodefeatures.edge_index (torch.Tensor or SparseTensor): The edge indices.edge_attr (torch.Tensor, optional): The edge features.(default: :obj:`None`)return_attention_weights (bool, optional): If set to :obj:`True`,will additionally return the tuple:obj:`(edge_index, attention_weights)`, holding the computedattention weights for each edge. (default: :obj:`None`)"""H, C = self.heads, self.out_channelsif isinstance(x, Tensor):x = (x, x)query = self.lin_query(x[1]).view(-1, H, C)key = self.lin_key(x[0]).view(-1, H, C)value = self.lin_value(x[0]).view(-1, H, C)# propagate_type: (query: Tensor, key:Tensor, value: Tensor,#                  edge_attr: OptTensor)out = self.propagate(edge_index, query=query, key=key, value=value,edge_attr=edge_attr)alpha = self._alphaself._alpha = Noneif self.concat:out = out.view(-1, self.heads * self.out_channels)else:out = out.mean(dim=1)if self.root_weight:x_r = self.lin_skip(x[1])if self.lin_beta is not None:beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))beta = beta.sigmoid()out = beta * x_r + (1 - beta) * outelse:out = out + x_rif isinstance(return_attention_weights, bool):assert alpha is not Noneif isinstance(edge_index, Tensor):return out, (edge_index, alpha)elif isinstance(edge_index, SparseTensor):return out, edge_index.set_value(alpha, layout='coo')else:return outdef message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,edge_attr: OptTensor, index: Tensor, ptr: OptTensor,size_i: Optional[int]) -> Tensor:if self.lin_edge is not None:assert edge_attr is not Noneedge_attr = self.lin_edge(edge_attr).view(-1, self.heads,self.out_channels)key_j = key_j + edge_attralpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)alpha = softmax(alpha, index, ptr, size_i)self._alpha = alphaalpha = F.dropout(alpha, p=self.dropout, training=self.training)out = value_jif edge_attr is not None:out = out + edge_attrout = out * alpha.view(-1, self.heads, 1)return outdef __repr__(self) -> str:return (f'{self.__class__.__name__}({self.in_channels}, 'f'{self.out_channels}, heads={self.heads})')

2.详细解释一下

几个重要的参数

in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

out_channels (int): Size of each output sample.

heads (int, optional): Number of multi-head-attentions. (default: :obj:`1`)

怎么理解这几个参数?

 

  • in_channels 表示每个输入样本的大小。如果设置为整数,则表示所有输入样本的大小相同;如果设置为 -1,则表示输入样本的大小将从 forward 方法的第一个输入中推导出来;如果设置为元组,则表示输入样本的大小对应于源维度和目标维度的大小。

  • out_channels 表示每个输出样本的大小,即经过卷积操作后产生的特征向量的维度大小。

 

当使用 tg.nn.TransformerConv 时,可以通过以下方式理解 in_channelsout_channels

假设我们有一个图数据集,每个节点都有一个 10 维的特征向量表示。那么在这种情况下:

  • 如果我们想将每个节点的特征向量作为输入,然后使用 tg.nn.TransformerConv 进行卷积操作,那么 in_channels 应该设置为 10,表示每个输入样本的大小为 10。

  • 假设我们想将节点的特征向量转换为一个 16 维的特征向量,那么 out_channels 应该设置为 16,表示每个输出样本的大小为 16,即经过卷积操作后每个节点的特征向量将变为 16 维。

  • tg.nn.TransformerConv 中,heads 参数表示多头注意力的数量。举个例子,如果 heads 参数设置为 4,那么模型将学习 4 组注意力权重,每组权重都用于计算输入的不同子空间的注意力,然后将这些头的输出进行合并以产生最终的输出。

 举个整体的例子

我们有一个输入张量 x,它的形状是 (batch_size, seq_length, input_dim),其中:

  • batch_size 表示批量大小;
  • seq_length 表示序列长度;
  • input_dim 表示输入特征的维度。

现在假设我们使用了 tg.nn.TransformerConv,并设置 heads=2,那么模型将学习两组注意力权重,每组用于计算不同的注意力。输出张量的形状将取决于 out_channels 参数,我们假设 out_channels=64

import torch
import torch_geometric.nn as tg# 假设输入张量的形状是 (batch_size, seq_length, input_dim)
x = torch.randn(32, 10, 128)  # 32 个样本,每个样本有 10 个时间步,每个时间步有 128 个特征# 创建 TransformerConv 模型,设置 heads=2,out_channels=64
conv_layer = tg.nn.TransformerConv(in_channels=128, out_channels=64, heads=2)# 使用模型进行前向传播
output = conv_layer(x)print("输出张量的形状:", output.shape)

 2.1将特征映射到键值对中

在这里,通过线性变换层 Linear,输入特征被转换成了键(key)、查询(query)和数值(value)的表示形式,以便用于多头自注意力机制。

具体来说:

  • self.lin_key 用于将输入特征(in_channels[0])映射到键的表示形式。
  • self.lin_query 用于将输入特征(in_channels[1])映射到查询的表示形式。
  • self.lin_value 用于将输入特征(in_channels[0])映射到数值的表示形式。

 具体地,假设输入特征的维度是 (batch_size, num_nodes, in_channels),其中 batch_size 是批量大小,num_nodes 是节点数,in_channels 是输入特征的通道数。在映射到键的过程中,线性变换层的权重矩阵将是一个维度为 (in_channels, heads * out_channels) 的矩阵,其中 heads 是注意力头的数量,out_channels 是输出特征的通道数。因此,通过矩阵乘法运算,输入特征将被映射到一个新的特征空间,其维度为 (batch_size, num_nodes, heads, out_channels)。在这个新的特征空间中,每个节点的每个头都有一个键表示。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834105.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用凌鲨建立软件研发技能学习小组

凌鲨(OpenLinkSaas)的团队功能除了提供论坛功能&#xff0c;还能记录团队成员的成长记录。 使用方法 打开团队功能 团队功能在默认情况下是关闭的&#xff0c;你可以在登录后打开团队功能开关。 创建学习团队 日报/周报/个人目标一般是企业团队需要&#xff0c;建议关闭。 …

Shell生成支持x264的ffmpeg安卓全平台so

安卓 FFmpeg系列 第一章 Ubuntu生成ffmpeg安卓全平台so 第二章 Windows生成ffmpeg安卓全平台so 第三章 生成支持x264的ffmpeg安卓全平台so&#xff08;本章&#xff09; 文章目录 安卓 FFmpeg系列前言一、实现步骤1、下载x264源码2、交叉编译生成.a3、加入x264配置4、编译ffmp…

Redis 哨兵机制

文章目录 哨兵机制概念相关知识铺垫主从复制缺陷哨兵工作流程选举具体流程理解注意事项 哨兵机制概念 先抽象的理解&#xff0c;哨兵就像是监工&#xff0c;节点不干活了&#xff0c;就要有行动了。 Redis 的主从复制模式下&#xff0c;⼀旦主节点由于故障不能提供服务&#…

MVC WebAPI

创建项目 创建api控制器 》》》 web api 控制器要继承 ApiController 》》》 数据会自动装配 及自动绑定 》》》清除xml返回格式 //清除XML返回格式 GlobalConfiguration.Configuration.Formatters.XmlFormatter.SupportedMediaTypes.Clear(); 》》》跨越问题

【汇编语言小练习输入两个数字然后输出它们的和】

这个程序是用汇编语言编写一个简单的程序&#xff0c;它将从键盘输入两个数字&#xff0c;然后输出它们的和。 .MODEL SMALL .STACK 100H.DATAINPUT_MSG1 DB Enter the first number: $INPUT_MSG2 DB 13, 10, Enter the second number: $RESULT_MSG DB 13, 10, The sum is: $N…

2024服贸会,参展企业媒体宣传报道攻略

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 2024年中国国际服务贸易交易会&#xff08;简称“服贸会”&#xff09;是一个重要的国际贸易平台&#xff0c;对于参展企业来说&#xff0c;有效的媒体宣传报道对于提升品牌知名度、扩大…

FPGA第二篇,FPGA与CPU GPU APU DSP NPU TPU 之间的关系与区别

简介&#xff1a;首先&#xff0c;FPGA与CPU GPU APU NPU TPU DSP这些不同类型的处理器&#xff0c;可以被统称为"处理器"或者"加速器"。它们在计算机硬件系统中承担着核心的计算和处理任务&#xff0c;可以说是系统的"大脑"和"加速引擎&qu…

社区奶柜:小本创业,大有可为

社区奶柜&#xff1a;小本创业&#xff0c;大有可为 在快节奏的现代生活中&#xff0c;人们对健康、便捷生活方式的追求日益增长。社区奶柜加盟项目&#xff0c;正是应运而生&#xff0c;它不仅满足了居民对于新鲜、营养乳制品的日常需求&#xff0c;也为寻求创业机会的您铺设…

excel中图片url转为jpg

步骤 打开Excel&#xff0c;并按下 Alt F11 打开VBA编辑器。在VBA编辑器中&#xff0c;插入一个新的模块&#xff08;右键点击项目资源管理器中的模块 -> 插入 -> 模块&#xff09;。在新模块的代码窗口中&#xff0c;复制并粘贴以下示例代码。根据需要修改代码中的变量…

力扣2105---给植物浇水II(Java、模拟、双指针)

题目描述&#xff1a; Alice 和 Bob 打算给花园里的 n 株植物浇水。植物排成一行&#xff0c;从左到右进行标记&#xff0c;编号从 0 到 n - 1 。其中&#xff0c;第 i 株植物的位置是 x i 。 每一株植物都需要浇特定量的水。Alice 和 Bob 每人有一个水罐&#xff0c;最初是…

AS01_BAPI:BAPI_FIXEDASSET_OVRTAKE_CREATE_创建资产主数据

AS01_创建资产主数据 一、功能介绍 使用事务码AS01创建资产主数据 二、程序代码 程序代码&#xff1a; REPORT zfir002.TYPE-POOLS: slis, icon. TYPES: BEGIN OF ty_file,zanln TYPE anla-anln1, "资产编号anln2 TYPE anla-anln2, "资产次级编号bukrs …

FastAPI vs Flask: 选择最适合您的 Python Web 框架

文章目录 1. 简介2. 安装和设置3. 路由和视图4. 自动文档生成5. 数据验证和序列化6. 性能和异步支持结论 在 Python Web 开发领域&#xff0c;FastAPI 和 Flask 是两个备受欢迎的选择。它们都提供了强大的工具和功能&#xff0c;但是在某些方面有所不同。本文将比较 FastAPI 和…

Ubuntu 下串口工具:Minicom、CuteCom 和 Screen

在 Ubuntu 中&#xff0c;对于串口通信工具的选择&#xff0c;虽然没有一个绝对的 “最好用” 的排名&#xff0c;但根据用户反馈和工具的流行程度&#xff0c;Minicom、CuteCom 和 Screen 这三个工具通常被认为是较为受欢迎和实用的。 一、简介&#xff1a; Minicom&#xff…

【论文笔记】KAN: Kolmogorov-Arnold Networks 全新神经网络架构KAN,MLP的潜在替代者

KAN: Kolmogorov-Arnold Networks code&#xff1a;https://github.com/KindXiaoming/pykan Background ​ 多层感知机&#xff08;MLP&#xff09;是机器学习中拟合非线性函数的默认模型&#xff0c;在众多深度学习模型中被广泛的应用。但MLP存在很多明显的缺点&#xff1a;…

跨协议通讯无缝对接:Modbus-BACnet楼宇智能转换器深度解析

在现代化的建筑群里&#xff0c;智能楼宇管理系统如同神经系统&#xff0c;协调着各设备的运行。某大型商业综合体&#xff0c;集购物中心、办公区、酒店于一体&#xff0c;面对着来自不同供应商的设备&#xff0c;如何实现统一管理和高效通讯成了首要挑战。特别是其内部既有采…

ADC模-数转换原理与实现

1. 今日摸鱼计划 今天来学习一下ADC的原理&#xff0c;然后把ADC给实现 ADC芯片:ADC128S102 视频&#xff1a; 18A_基于SPI接口的ADC芯片功能和接口时序介绍_哔哩哔哩_bilibili 18B_使用线性序列机思路分析SPI接口的ADC芯片接口时序_哔哩哔哩_bilibili 18C_基于线性序列机的S…

neo4j-5.11.0安装APOC插件or配置允许使用过程的权限

在已经安装好neo4j和jdk的情况下安装apoc组件&#xff0c;之前使用neo4j-community-4.4.30&#xff0c;可以找到配置apoc-4.4.0.22-all.jar&#xff0c;但是高版本neo4j对应没有apoc-X.X.X-all.jar。解决如下所示&#xff1a; 1.安装好JDK与neo4j 已经安装对应版本的JDK 17.0…

MySQL数据库及数据表的创建

1.创建一个名叫 db_classes 的数据库&#xff1a; 创建一个叫 db_classes 的数据库MySQL命令&#xff1a; create database db_classes; 运行效果&#xff1a; 创建数据库后查看该数据库基本信息MySQL命令&#xff1a; show create database db_classes; 运行效果&#xff…

力扣刷题--数组--第三天

今天再做两道二分查找的题目&#xff0c;关于二分查找的知识可看我前两篇博客。话不多说&#xff0c;直接开干&#xff01; 题目1&#xff1a;69.x 的平方根 题目详情&#xff1a;   给你一个非负整数 x &#xff0c;计算并返回 x 的 算术平方根 。由于返回类型是整数&#…

VScode通过ssh远程连接服务器被拒绝:permission denied, please try again

使用场景&#xff1a; 使用windows系统下的vscode远程连接服务器的linux系统&#xff0c;终端提示permission denied, please try again,但是使用cmd是可以远程登录的。 解决办法&#xff1a; 前提条件windows端的vscode安装了ssh远程连接的相关插件Remote - SSH&#xff0c;…