最近看论文看到了图卷积神经网络的内容,之前整理过图神经网络的内容,这里再补充一下,方便以后查阅。
图卷积神经网络(Graph Convolutional Network, GCN)
- 图卷积神经网络
- 1. 什么是图卷积神经网络(GCN)?
- 2. GCN的原理
- 2.1 图的表示
- 2.2 谱图卷积
- 2.3 GCN的层级传播规则
- 2.4 消息传递框架
- 3. GCN的结构
- 4. GCN的应用
- 5. GCN的优点与局限性
- 优点:
- 局限性:
- 6. GCN代码示例
- 代码说明:
- 7. 如何扩展GCN?
- 8. 总结
- 谱图理论(Spectral Graph Theory)
- 1. 背景:为什么需要谱图理论?
- 2. 什么是谱图理论?
- 3. 图的拉普拉斯矩阵
- 3.1 定义
- 3.2 归一化拉普拉斯矩阵
- 3.3 拉普拉斯矩阵的性质
- 4. 通过拉普拉斯矩阵定义卷积操作
- 4.1 图傅里叶变换
- 4.2 谱卷积
- 4.3 GCN的简化
- 5. 为什么用拉普拉斯矩阵定义卷积?
- 6. 局限性与改进
- 图注意力网络(GAT)
- 1. 背景:为什么需要GAT?
- 2. GAT的原理
- 3. GAT的数学公式
- 3.1 注意力系数
- 3.2 归一化注意力系数
- 3.3 加权聚合
- 3.4 多头注意力
- 3.5 完整传播规则
- 4. GAT的结构
- 5. GAT与GCN的对比
- 6. GAT的优势
- 7. GAT的局限性
- 8. GAT的应用
- 9. GAT的实现与代码
- 代码说明:
- 10. GAT的扩展与改进
- 11. 总结
先验知识:图神经网络 GNN
图卷积神经网络
1. 什么是图卷积神经网络(GCN)?
GCN是卷积神经网络(CNN)的扩展,适用于非欧几里得空间的数据(如图结构数据)。传统CNN处理规则的网格数据(如图像或时间序列),而GCN处理节点和边构成的图结构数据。图由节点(vertices)和边(edges)组成,节点表示实体,边表示实体间的关系。
GCN的核心思想是通过消息传递机制,利用图的拓扑结构,将节点特征与其邻居的特征聚合,从而学习节点的表示(embedding)。这些表示可用于节点分类、链接预测或图分类等任务。
2. GCN的原理
GCN基于谱图理论(Spectral Graph Theory)和消息传递框架。以下是其核心原理:
2.1 图的表示
一个图 G = ( V , E ) G = (V, E) G=(V,E) 由以下部分组成:
- 节点集 V V V,节点数为 n n n。
- 边集 E E E,表示节点之间的连接。
- 邻接矩阵 A A A,大小为 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij=1 表示节点 i i i 和 j j j 之间有边,否则为 0 0 0。
- 节点特征矩阵 X X X,大小为 n × d n \times d n×d,其中每行是节点 i i i 的 d d d 维特征向量。
- 度矩阵 D D D,对角矩阵,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii=∑jAij 表示节点 i i i 的度。
2.2 谱图卷积
GCN最初基于谱图理论,通过图的拉普拉斯矩阵定义卷积操作。图拉普拉斯矩阵定义为:
L = D − A L = D - A L=D−A
归一化拉普拉斯矩阵为:
L n o r m = I − D − 1 / 2 A D − 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm=I−D−1/2AD−1/2
其中 I I I 是单位矩阵, D − 1 / 2 D^{-1/2} D−1/2 是度矩阵的对角元素的倒数平方根。
谱图卷积通过拉普拉斯矩阵的特征分解,将卷积操作定义在图的频域上。然而,计算特征分解的复杂度较高( O ( n 3 ) O(n^3) O(n3)),因此实际中常用近似方法。
2.3 GCN的层级传播规则
现代GCN(如Kipf & Welling, 2017)使用简化的消息传递机制。一层GCN的传播规则为:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l))
其中:
- H ( l ) H^{(l)} H(l):第 l l l 层的节点特征矩阵, H ( 0 ) = X H^{(0)} = X H(0)=X。
- A ~ = A + I \tilde{A} = A + I A~=A+I:添加自环的邻接矩阵(每个节点与自身连接)。
- D ~ \tilde{D} D~: A ~ \tilde{A} A~ 对应的度矩阵, D ~ i i = ∑ j A ~ i j \tilde{D}_{ii} = \sum_j \tilde{A}_{ij} D~ii=∑jA~ij。
- W ( l ) W^{(l)} W(l):第 l l l 层的可学习权重矩阵。
- σ \sigma σ:激活函数(如ReLU)。
- D ~ − 1 / 2 A ~ D ~ − 1 / 2 \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} D~−1/2A~D~−1/2:归一化的邻接矩阵,用于平衡不同度节点的影响。
直观解释:
- 每层GCN聚合节点的邻居特征(包括自身),通过 A ~ \tilde{A} A~ 实现。
- 归一化( D ~ − 1 / 2 \tilde{D}^{-1/2} D~−1/2) 防止高阶节点主导聚合。
- 权重矩阵 W ( l ) W^{(l)} W(l) 进行特征变换,激活函数引入非线性。
2.4 消息传递框架
GCN可以看作消息传递神经网络(Message Passing Neural Network, MPNN)的一种:
- 聚合:收集邻居节点的特征(通过 A ~ \tilde{A} A~)。
- 更新:结合自身特征和聚合特征,更新节点表示(通过 W ( l ) W^{(l)} W(l) 和 σ \sigma σ)。
3. GCN的结构
一个典型的GCN模型包含以下部分:
- 输入层:接受图的邻接矩阵 A A A 和节点特征矩阵 X X X。
- 多层GCN:堆叠若干GCN层,每层执行特征聚合和变换。
- 输出层:
- 对于节点分类,输出每个节点的类别概率(通过Softmax)。
- 对于图分类,需要池化层(如全局平均池化)将节点特征汇总为图特征。
- 损失函数:
- 节点分类:交叉熵损失。
- 图分类:交叉熵或回归损失,取决于任务。
4. GCN的应用
GCN在许多领域有广泛应用:
- 社交网络:
- 节点分类:预测用户兴趣或社区归属。
- 链接预测:推荐好友或合作关系。
- 推荐系统:
- 使用用户-物品交互图,预测用户偏好。
- 化学分子分析:
- 图表示分子结构,预测分子性质(如毒性或溶解度)。
- 知识图谱:
- 实体分类或关系预测。
- 生物信息学:
- 分析蛋白质相互作用网络。
5. GCN的优点与局限性
优点:
- 适应图结构:能有效处理非规则的图数据。
- 局部性:通过邻居聚合,捕获局部拓扑信息。
- 可扩展性:可以堆叠多层,学习复杂模式。
局限性:
- 过平滑问题:
- 堆叠过多GCN层会导致节点特征趋于相似,丢失区分度。
- 固定拓扑:
- GCN依赖静态图结构,无法直接处理动态图。
- 计算复杂度:
- 对于大规模图,矩阵运算(如 A ~ H \tilde{A} H A~H)可能耗时。
- 边信息:
- 基本GCN不考虑边的权重或类型,后续变体(如GAT)改进此问题。
6. GCN代码示例
以下是一个基于PyTorch Geometric的GCN实现,用于节点分类任务。我们使用Cora数据集(一个常用的学术引用网络数据集),其中节点是论文,边是引用关系,目标是预测论文的类别。
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定义GCN模型
class GCN(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels):super(GCN, self).__init__()self.conv1 = GCNConv(in_channels, hidden_channels)self.conv2 = GCNConv(hidden_channels, out_channels)def forward(self, x, edge_index):# 第一层GCNx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, p=0.5, training=self.training)# 第二层GCNx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GCN(in_channels=dataset.num_features,hidden_channels=16,out_channels=dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 测试模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 训练循环
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最终测试
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')
代码说明:
- 数据集:Cora数据集包含2708个节点(论文),每个节点有1433维特征(词袋表示),7个类别,图有10556条边。
- 模型:两层GCN,第一层将特征从1433维降到16维,第二层输出7维(对应类别)。
- 训练:使用Adam优化器,交叉熵损失,仅对训练掩码(
train_mask
)的节点计算损失。 - 测试:在测试掩码(
test_mask
)上计算分类准确率。 - 依赖:需要安装
torch
和torch_geometric
:pip install torch torch-geometric
7. 如何扩展GCN?
GCN是图神经网络(GNN)的基础,许多改进模型基于GCN:
- 图注意力网络(GAT):引入注意力机制,动态分配邻居权重。
- GraphSAGE:通过采样邻居,适应大规模图。
- JK-Net:通过跳跃连接缓解过平滑问题。
- APPNP:结合个性化PageRank,增强传播效果。
8. 总结
GCN是一种强大的图神经网络,通过消息传递机制聚合邻居特征,学习图中节点的表示。其基于谱图理论,结构简单,适合节点分类、链接预测等任务。尽管GCN在许多领域表现优异,但过平滑和计算复杂度问题需要通过变体或优化解决。
谱图理论(Spectral Graph Theory)
1. 背景:为什么需要谱图理论?
在传统卷积神经网络(CNN)中,卷积操作适用于规则的网格数据(如图像的像素网格),通过滑动窗口提取局部特征。然而,图结构数据(如社交网络、分子结构)是非规则的,非欧几里得空间的数据,节点之间的连接(边)没有固定模式。因此,直接应用传统卷积不可行。
谱图理论提供了一种数学框架,通过分析图的拓扑结构(邻接关系),将图上的操作(如卷积)定义在频域(类似于傅里叶变换)。GCN的早期工作(例如Bruna等人,2013)利用谱图理论,将图上的卷积定义为基于图拉普拉斯矩阵的操作。
2. 什么是谱图理论?
谱图理论是图论的一个分支,研究图的性质通过其矩阵表示(如邻接矩阵或拉普拉斯矩阵)的特征值(eigenvalues)和特征向量(eigenvectors)。这些特征值和特征向量描述了图的拓扑结构,例如连通性、聚类特性等。
在GCN中,谱图理论的核心思想是将图上的信号(节点特征)投影到图的频域(由拉普拉斯矩阵的特征向量定义),进行类似傅里叶变换的操作,再转换回空间域。这种方法允许我们在图上定义卷积,类似于图像上的卷积。
3. 图的拉普拉斯矩阵
拉普拉斯矩阵(Laplacian Matrix)是图的矩阵表示,用于捕捉图的拓扑结构。以下是其定义和性质:
3.1 定义
对于一个无向图 G = ( V , E ) G = (V, E) G=(V,E),有 n n n 个节点,拉普拉斯矩阵 L L L 定义为:
L = D − A L = D - A L=D−A
- A A A:邻接矩阵,大小 n × n n \times n n×n,其中 A i j = 1 A_{ij} = 1 Aij=1 如果节点 i i i 和 j j j 之间有边,否则为 0 0 0。
- D D D:度矩阵,对角矩阵,大小 n × n n \times n n×n,其中 D i i = ∑ j A i j D_{ii} = \sum_j A_{ij} Dii=∑jAij 表示节点 i i i 的度(连接的边数),非对角元素为 0 0 0。
例如,对于一个简单图:
- 邻接矩阵 A A A:
A = [ 0 1 1 1 1 0 0 0 1 0 0 1 1 0 1 0 ] A = \begin{bmatrix} 0 & 1 & 1 & 1 \\ 1 & 0 & 0 & 0 \\ 1 & 0 & 0 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix} A= 0111100010011010 - 度矩阵 D D D:
D = [ 3 0 0 0 0 1 0 0 0 0 2 0 0 0 0 2 ] D = \begin{bmatrix} 3 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 2 & 0 \\ 0 & 0 & 0 & 2 \end{bmatrix} D= 3000010000200002 - 拉普拉斯矩阵 L L L:
L = D − A = [ 3 − 1 − 1 − 1 − 1 1 0 0 − 1 0 2 − 1 − 1 0 − 1 2 ] L = D - A = \begin{bmatrix} 3 & -1 & -1 & -1 \\ -1 & 1 & 0 & 0 \\ -1 & 0 & 2 & -1 \\ -1 & 0 & -1 & 2 \end{bmatrix} L=D−A= 3−1−1−1−1100−102−1−10−12
3.2 归一化拉普拉斯矩阵
为了平衡不同度节点的影响,常用归一化拉普拉斯矩阵:
L n o r m = I − D − 1 / 2 A D − 1 / 2 L_{norm} = I - D^{-1/2} A D^{-1/2} Lnorm=I−D−1/2AD−1/2
- D − 1 / 2 D^{-1/2} D−1/2:对角矩阵,其对角元素为 D i i − 1 / 2 = 1 / 度 i D_{ii}^{-1/2} = 1 / \sqrt{\text{度}_i} Dii−1/2=1/度i。
- I I I:单位矩阵。
归一化拉普拉斯矩阵的特征值在 [ 0 , 2 ] [0, 2] [0,2] 范围内,适合数值计算。
3.3 拉普拉斯矩阵的性质
- 对称性:对于无向图, L L L 是对称矩阵,因此有实特征值和正交特征向量。
- 正半定性: L L L 的特征值非负,反映图的连通性(例如,特征值 0 0 0 的重数等于连通分量的个数)。
- 频域解释:拉普拉斯矩阵的特征向量形成图的“频域基”,类似傅里叶变换中的正弦和余弦函数。特征值表示“频率”,低频率对应平滑信号,高频率对应快速变化的信号。
4. 通过拉普拉斯矩阵定义卷积操作
在谱图理论中,图上的卷积操作通过拉普拉斯矩阵的特征分解定义,类似于信号处理中的傅里叶变换。以下是具体步骤:
4.1 图傅里叶变换
拉普拉斯矩阵 L L L 可以分解为:
L = U Λ U T L = U \Lambda U^T L=UΛUT
- U U U: 特征向量矩阵,列是 L L L 的特征向量 u 1 , u 2 , … , u n u_1, u_2, \ldots, u_n u1,u2,…,un,表示图的“频域基”。
- Λ \Lambda Λ: 对角矩阵,对角元素是特征值 λ 1 , λ 2 , … , λ n \lambda_1, \lambda_2, \ldots, \lambda_n λ1,λ2,…,λn,表示“频率”。
对于节点特征向量 x ∈ R n x \in \mathbb{R}^n x∈Rn(每个节点一个标量特征),其图傅里叶变换定义为:
x ^ = U T x \hat{x} = U^T x x^=UTx
- x ^ \hat{x} x^ 是频域中的系数,表示 x x x 在特征向量基上的投影。
- 逆变换为:
x = U x ^ x = U \hat{x} x=Ux^
4.2 谱卷积
在频域中,卷积等价于逐频率相乘。假设有一个滤波器(卷积核) g g g,其频域表示为 g ( Λ ) g(\Lambda) g(Λ)(对角矩阵,元素为 g ( λ i ) g(\lambda_i) g(λi))。图上的卷积定义为:
x ∗ g = U g ( Λ ) U T x x * g = U g(\Lambda) U^T x x∗g=Ug(Λ)UTx
- g ( Λ ) g(\Lambda) g(Λ): 滤波器在频域的响应,控制如何放大或抑制不同频率的信号。
- U T x U^T x UTx: 将信号 x x x 转换为频域。
- U g ( Λ ) U T x U g(\Lambda) U^T x Ug(Λ)UTx: 应用滤波器后转换回空间域。
因此,图卷积是通过拉普拉斯矩阵的特征分解,将节点特征在频域中与滤波器结合,再转换回节点特征。
4.3 GCN的简化
早期谱GCN(如Bruna等人)直接使用上述卷积,但计算 U U U 和 U T U^T UT 的复杂度为 O ( n 3 ) O(n^3) O(n3),对大图不可行。Kipf & Welling (2017) 提出了简化版本:
- 使用多项式滤波器:假设 g ( Λ ) g(\Lambda) g(Λ) 是拉普拉斯矩阵特征值的多项式,例如 g ( Λ ) = ∑ k θ k Λ k g(\Lambda) = \sum_k \theta_k \Lambda^k g(Λ)=∑kθkΛk。
- 近似为低阶多项式(如一阶),避免特征分解。
- 最终传播规则为:
H ( l + 1 ) = σ ( D ~ − 1 / 2 A ~ D ~ − 1 / 2 H ( l ) W ( l ) ) H^{(l+1)} = \sigma \left( \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H^{(l)} W^{(l)} \right) H(l+1)=σ(D~−1/2A~D~−1/2H(l)W(l))
其中 A ~ = A + I \tilde{A} = A + I A~=A+I, D ~ \tilde{D} D~ 是 A ~ \tilde{A} A~ 的度矩阵。这种形式避免了昂贵的矩阵分解,直接操作邻接矩阵。
5. 为什么用拉普拉斯矩阵定义卷积?
拉普拉斯矩阵在图卷积中有以下优势:
- 捕捉拓扑结构:拉普拉斯矩阵编码了图的连接性(邻接关系和节点度),适合定义局部聚合操作。
- 频域解释:通过特征分解,拉普拉斯矩阵提供了一种频域视角,类似图像上的傅里叶变换,便于定义卷积。
- 平滑性:拉普拉斯矩阵与图的平滑性相关(例如, x T L x x^T L x xTLx 测量信号 x x x 在图上的变化),卷积操作可以平滑节点特征,聚合邻居信息。
- 数学优雅:谱图理论提供了统一的框架,将图上的操作与传统信号处理连接起来。
6. 局限性与改进
基于谱图理论的GCN有以下局限性:
- 计算复杂度:特征分解对大图不可行( O ( n 3 ) O(n^3) O(n3))。
- 泛化性:谱方法依赖图的固定拉普拉斯矩阵,难以直接应用于不同结构的图。
- 局部性:谱卷积本质上是全局操作,可能忽略局部特征。
因此,后续工作(如GraphSAGE、GAT)转向空间域方法,直接在图的拓扑上定义卷积(如邻居聚合),避免谱分解,提高效率和灵活性。
图注意力网络(GAT)
图注意力网络(Graph Attention Network, GAT)是一种图神经网络(Graph Neural Network, GNN)的变体,由Veličković等人于2017年提出(论文:《Graph Attention Networks》)。它通过引入注意力机制(Attention Mechanism)改进了传统的图卷积神经网络(GCN),能够动态地为不同邻居节点分配权重,从而更好地捕捉图结构中的异质性关系
1. 背景:为什么需要GAT?
图神经网络(如GCN)通过聚合节点邻居的特征来学习节点表示,适用于处理图结构数据(如社交网络、分子结构)。然而,GCN存在以下局限性:
- 等权重聚合:GCN假设所有邻居对节点的贡献相同(通过归一化的邻接矩阵),无法区分邻居的重要性。例如,在社交网络中,某些好友的影响可能更大。
- 固定拓扑依赖:GCN的聚合权重完全由图的拓扑结构(邻接矩阵和节点度)决定,缺乏灵活性。
- 无法捕捉异质性:对于高度异质的图(节点或边的关系差异显著),GCN的表现可能受限。
GAT通过引入注意力机制解决了这些问题,允许模型动态学习每个邻居的贡献权重,类似于自然语言处理中的Transformer模型。这种机制使GAT能够聚焦于对任务更重要的邻居,提高表示能力和泛化性。
2. GAT的原理
GAT的核心思想是通过注意力机制为每个节点的邻居分配不同的权重,然后基于这些权重聚合邻居特征。其操作可以概括为以下步骤:
- 计算注意力系数:为每条边(或邻居对)计算一个注意力分数,表示邻居的重要性。
- 归一化注意力系数:使用Softmax将注意力分数归一化为权重。
- 加权聚合:根据归一化的注意力权重,聚合邻居的特征。
- 多头注意力:可选地使用多组注意力机制(类似Transformer),增强模型表达能力。
GAT仍然基于消息传递框架(Message Passing Neural Network, MPNN),但其聚合方式比GCN更灵活。
3. GAT的数学公式
假设有一个图 G = ( V , E ) G = (V, E) G=(V,E),包含 n n n 个节点,节点特征矩阵为 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d,其中每个节点有 d d d 维特征。以下是GAT一层的主要数学表达。
3.1 注意力系数
对于节点 i i i 和其邻居 j ∈ N i j \in \mathcal{N}_i j∈Ni(包括节点 i i i 自身,若考虑自环),GAT首先将节点特征通过线性变换映射到新的特征空间:
h i = W x i , h j = W x j h_i = W x_i, \quad h_j = W x_j hi=Wxi,hj=Wxj
- W ∈ R d ′ × d W \in \mathbb{R}^{d' \times d} W∈Rd′×d:可学习的权重矩阵,将特征从 d d d 维映射到 d ′ d' d′ 维。
- x i , x j x_i, x_j xi,xj:节点 i i i 和 j j j 的输入特征。
- h i , h j h_i, h_j hi,hj:映射后的特征。
然后,计算节点 i i i 和 j j j 之间的注意力系数 e i j e_{ij} eij:
e i j = a ( W x i , W x j ) e_{ij} = a(W x_i, W x_j) eij=a(Wxi,Wxj)
- a ( ⋅ , ⋅ ) a(\cdot, \cdot) a(⋅,⋅):注意力函数,通常是一个前馈神经网络。例如,Veličković等人使用单层感知器:
e i j = LeakyReLU ( a T [ W x i ∥ W x j ] ) e_{ij} = \text{LeakyReLU} \left( a^T [W x_i \parallel W x_j] \right) eij=LeakyReLU(aT[Wxi∥Wxj])- a ∈ R 2 d ′ a \in \mathbb{R}^{2d'} a∈R2d′:可学习的注意力向量。
- [ W x i ∥ W x j ] [W x_i \parallel W x_j] [Wxi∥Wxj]:将 W x i W x_i Wxi 和 W x j W x_j Wxj 拼接,得到 2 d ′ 2d' 2d′ 维向量。
- LeakyReLU:激活函数,增加非线性。
3.2 归一化注意力系数
为了使注意力系数可比较(类似概率分布),对节点 i i i 的所有邻居 j ∈ N i j \in \mathcal{N}_i j∈Ni 应用Softmax归一化:
α i j = exp ( e i j ) ∑ k ∈ N i exp ( e i k ) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})} αij=∑k∈Niexp(eik)exp(eij)
- α i j \alpha_{ij} αij:归一化的注意力权重,表示邻居 j j j 对节点 i i i 的相对重要性。
3.3 加权聚合
使用归一化的注意力权重,聚合邻居的特征,更新节点 i i i 的表示:
h i ′ = σ ( ∑ j ∈ N i α i j W x j ) h_i' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W x_j \right) hi′=σ j∈Ni∑αijWxj
- h i ′ h_i' hi′:节点 i i i 的更新特征。
- σ \sigma σ:激活函数(如ELU或ReLU)。
- W x j W x_j Wxj:邻居 j j j 的变换特征。
3.4 多头注意力
为了增强模型的表达能力和稳定性,GAT通常使用多头注意力(Multi-Head Attention)。运行 K K K 个独立的注意力机制,得到 K K K 组特征,然后拼接或平均:
- 中间层:拼接多头输出:
h i ′ = ∥ k = 1 K σ ( ∑ j ∈ N i α i j k W k x j ) h_i' = \parallel_{k=1}^K \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi′=∥k=1Kσ j∈Ni∑αijkWkxj - W k W^k Wk:第 k k k 头的权重矩阵。
- α i j k \alpha_{ij}^k αijk:第 k k k 头的注意力权重。
- 输出维度为 K × d ′ K \times d' K×d′。
- 输出层:平均多头输出:
h i ′ = σ ( 1 K ∑ k = 1 K ∑ j ∈ N i α i j k W k x j ) h_i' = \sigma \left( \frac{1}{K} \sum_{k=1}^K \sum_{j \in \mathcal{N}_i} \alpha_{ij}^k W^k x_j \right) hi′=σ K1k=1∑Kj∈Ni∑αijkWkxj - 平均操作减少参数量,适合分类任务。
3.5 完整传播规则
一层GAT的传播规则可以总结为:
H ′ = σ ( ∑ j ∈ N i α i j W X ) H' = \sigma \left( \sum_{j \in \mathcal{N}_i} \alpha_{ij} W X \right) H′=σ j∈Ni∑αijWX
其中 H ′ H' H′ 是更新后的特征矩阵, X X X 是输入特征矩阵, α i j \alpha_{ij} αij 通过注意力机制计算。
4. GAT的结构
一个典型的GAT模型包含以下部分:
- 输入层:
- 输入:图的邻接矩阵(或边索引列表)和节点特征矩阵 X X X。
- 多层GAT:
- 堆叠若干GAT层,每层执行注意力机制和特征聚合。
- 中间层通常使用多头注意力(拼接),输出层可能使用单头或平均。
- 每层后可添加激活函数(如ELU)和Dropout(防止过拟合)。
- 输出层:
- 节点分类:通过Softmax输出每个节点的类别概率。
- 图分类:通过池化(如全局平均池化)将节点特征汇总为图特征。
- 损失函数:
- 节点分类:交叉熵损失。
- 图分类:交叉熵或回归损失,取决于任务。
5. GAT与GCN的对比
特性 | GCN | GAT |
---|---|---|
邻居聚合方式 | 等权重(基于归一化邻接矩阵) | 动态权重(通过注意力机制) |
权重计算 | 固定(由图结构决定) | 可学习(注意力系数动态调整) |
表达能力 | 较弱(无法区分邻居重要性) | 较强(捕捉异质性关系) |
计算复杂度 | 较低(矩阵乘法) | 较高(需计算注意力系数) |
过平滑问题 | 显著(多层后特征趋同) | 较轻(注意力机制保留差异) |
直观解释:
- GCN像“平均池化”,对所有邻居一视同仁。
- GAT像“加权池化”,根据任务动态选择重要邻居,类似“聪明地听意见”。
6. GAT的优势
- 动态权重分配:
- 注意力机制允许模型根据任务自动学习邻居的重要性。例如,在社交网络中,某些好友的影响可能更大。
- 捕捉异质性:
- GAT能处理节点或边关系差异显著的图,适合复杂网络。
- 多头注意力:
- 类似Transformer,增强模型表达能力,捕捉多种关系模式。
- 可解释性:
- 注意力系数 α i j \alpha_{ij} αij 可视化,揭示哪些邻居对预测更重要。
- 缓解过平滑:
- 相比GCN,GAT通过选择性聚合减少多层后特征趋同的问题。
7. GAT的局限性
- 计算复杂度:
- 计算注意力系数需要为每条边执行操作,复杂度为 O ( ∣ E ∣ ⋅ d ) O(|E| \cdot d) O(∣E∣⋅d),对稠密图或大图计算成本高。
- 多头注意力进一步增加计算量。
- 内存需求:
- 存储注意力系数和多头特征需要更多内存。
- 稳定性问题:
- 注意力机制可能导致训练不稳定,尤其在深层网络中,需小心调整超参数(如Dropout率、学习率)。
- 边信息限制:
- 基本GAT不直接利用边特征(如权重或类型),需扩展模型(如EGAT)。
- 过拟合风险:
- 在小图或稀疏图上,注意力机制可能过拟合,需正则化(如Dropout)。
8. GAT的应用
GAT在许多图结构数据的任务中表现出色,包括:
- 社交网络:
- 节点分类:预测用户兴趣、社区归属。
- 链接预测:推荐好友或合作关系。
- 推荐系统:
- 使用用户-物品交互图,预测用户偏好。
- 化学分子分析:
- 图表示分子结构,预测分子性质(如毒性、溶解度)。
- 知识图谱:
- 实体分类或关系预测。
- 生物信息学:
- 分析蛋白质相互作用网络,预测蛋白质功能。
- 交通网络:
- 预测交通流量或路径优化。
9. GAT的实现与代码
以下是一个基于PyTorch Geometric的GAT实现,用于节点分类任务。我们使用Cora数据集(一个常用的学术引用网络数据集),其中节点是论文,边是引用关系,目标是预测论文的类别。
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GATConv# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]# 定义GAT模型
class GAT(torch.nn.Module):def __init__(self, in_channels, hidden_channels, out_channels, heads=8):super(GAT, self).__init__()# 第一层GAT,使用多头注意力self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)# 第二层GAT,输出类别self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)def forward(self, x, edge_index):# 第一层GATx = self.conv1(x, edge_index)x = F.elu(x)x = F.dropout(x, p=0.6, training=self.training)# 第二层GATx = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)# 初始化模型
model = GAT(in_channels=dataset.num_features,hidden_channels=8,out_channels=dataset.num_classes,heads=8).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)# 训练模型
def train():model.train()optimizer.zero_grad()out = model(data.x, data.edge_index)loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])loss.backward()optimizer.step()return loss.item()# 测试模型
def test():model.eval()out = model(data.x, data.edge_index)pred = out.argmax(dim=1)acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()return acc# 训练循环
for epoch in range(200):loss = train()if epoch % 10 == 0:acc = test()print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {acc:.4f}')# 最终测试
final_acc = test()
print(f'Final Test Accuracy: {final_acc:.4f}')
代码说明:
- 数据集:Cora数据集,包含2708个节点(论文),每个节点有1433维特征(词袋表示),7个类别,图有10556条边。
- 模型:两层GAT,第一层将特征从1433维降到8维(使用多头注意力),第二层输出7维(对应类别)。
- 注意力机制:GAT使用注意力系数动态加权邻居特征,增强模型对重要邻居的关注。
- 训练:使用Adam优化器,交叉熵损失,仅对训练掩码(
train_mask
)的节点计算损失。 - 测试:在测试掩码(
test_mask
)上计算分类准确率。 - 依赖:需要安装
torch
和torch_geometric
:pip install torch torch-geometric
10. GAT的扩展与改进
GAT是GNN领域的重要进展,许多后续工作在其基础上改进:
- GATv2(2021):改进注意力机制,解决原始GAT的表达能力瓶颈,增强性能。
- EGAT:引入边特征,扩展到加权或有类型的图。
- HGT(Heterogeneous Graph Transformer):结合GAT和Transformer,处理异构图。
- 采样优化:如GraphSAGE的采样策略,结合GAT,适应大规模图。
- 动态图:扩展GAT到时序图,处理动态拓扑。
11. 总结
图注意力网络(GAT)通过引入注意力机制,改进了GCN的局限性,能够动态学习邻居的重要性,增强对异质图的建模能力。其核心是计算注意力系数、归一化加权和多头注意力,数学上基于消息传递框架。GAT在社交网络、推荐系统、化学等领域的节点分类、链接预测等任务中表现优异,但计算复杂度和内存需求是其挑战。相比GCN,GAT更灵活、可解释,但在实际应用中需权衡性能和成本。