图神经网络:GAT图注意力网络原理和源码解读(tensorflow)
图神经网络:GAT图注意力网络原理和源码解读(tensorflow)
图神经网络(GNN)在处理图结构数据时展现出强大的能力,而GAT(Graph Attention Network)作为其中的重要分支,通过引入注意力机制进一步提升了模型的表征能力。本文将从GNN、GCN到GAT的演变过程进行详细讲解,并深入解析注意力机制在图结构中的应用原理。
GNN到GAT的演变
GNN(Graph Neural Network):学习邻居节点聚合到中心节点的方式,传统的GNN采用求和或求平均的方式,每个邻居节点的权重相等。
GCN(Graph Convolutional Network):对邻居聚合方式进行改进,采用邻接矩阵的对称归一化,考虑节点度的大小进行权重调整,但仍然是基于规则的权重分配。
GAT(Graph Attention Network):认为不同邻居节点对中心节点的影响是不同的,通过注意力机制自动学习权重参数,提升表征能力。GAT使用邻居节点和中心节点的特征属性来确定权重,确保所有邻居节点的权重之和为1。
注意力机制原理
在注意力机制中,Source代表需要处理的信息,Query代表某种条件或先验信息,Attention Value是通过先验信息和Attention机制从Source中提取的信息。Source中的信息通过key-value对的形式表达,可以将key类比为信息的摘要,value类比为信息的全部内容。
从公式来看,注意力机制就是先计算出前提条件和每个要接受的信息的摘要部分的相关程度,以相关程度为权重再学习要接受的每个全部信息,最后每条信息加权求和得到结果。
GNN中的注意力机制
类比Key-Value注意力机制,在图结构中,中心节点就是Query,所有邻居节点的信息就是Source,Attention Value就是中心节点经过聚合之后的特征向量,Key和Value相同,就是邻居节点的特征向量。目标就是针对中心节点(Query)学习邻居节点(Source)的权重,再加权求和汇总到中心节点上形成新的特征向量表达(Attention Value)。
GAT示意图公式:
这个公式先简单理解一下,这个图旋转90度就是个逻辑回归一样的全连接。hi和hj代表的节点的特征向量或者当下的特征表达,i为中心节点,j为邻居节点,目标是计算eij两个节点之间的权重,Whi代表使用一个模型自己学习的共享的W向量来对原始特征向量做维度转换,比如原始是(512, 128),W为(128, 64),最终转化为(512, 64),i和j都转化之后拼接,再用一个全连接作为相似计算函数,激活函数为LeakyRelu,此时全连接之后产出一个值,所有的全连接值再做softmax归一化得到最终ij节点的权重值。为了保留图结构的连接关系,注意力只再中心节点和邻居节点之间计算,且一个注意力机制的a,W是共享的。
多头注意力机制
在计算出节点间的attention权重值之后新的中心节点表达如下,每个邻居的特征向量点乘一个维度转换向量参数之后,再乘上attention的权重,最后加权求和套一个激活函数输出下一层中心节点的特征表达
为了防止Attention过拟合,引入多个Attention,引入多套W和a,使得模型更加稳定
多头注意力机制GNN图示:
如图所示h1有h2,h3,h4,h5,h6这几个邻居,每个都算了三套三种颜色的Attention权重,最后拼接/平均得出下一层的h1表达
多头注意力机制GNN公式:
公式里面K就是几套注意力机制,
||
代表向量拼接,下面一种是求平均
源码解读
root@ubuntu:/home/gp/git/GAT# python execute_cora.py
Dataset: cora
----- Opt. hyperparams -----
lr: 0.005
l2_coef: 0.0005
----- Archi. hyperparams -----
nb. layers: 1
nb. units per layer: [8]
nb. attention heads: [8, 1]
residual: False
nonlinearity: <function elu at 0x7fd6bfeaf950>
model: <class 'models.gat.GAT'>
(2708, 2708)
(2708, 1433)
...
Training: loss = 1.13271, acc = 0.60000 | Val: loss = 1.00691, acc = 0.80200
Training:
这段代码展示了在Cora数据集上运行GAT模型的参数配置和训练结果。从输出可以看出,模型在训练集上的损失为1.13271,准确率为60%,在验证集上的损失为1.00691,准确率为80.2%。