知行编程网知行编程网  2022-06-17 11:00 知行编程网 隐藏边栏 |   抢沙发  19 
文章评分 0 次,平均分 0.0

深入理解图注意力机制

作者丨张昊、李牧非、王敏捷、张峥
来源丨https://zhuanlan.zhihu.com/p/57168713
编辑 | 极市平台

图卷积网络(GCN)告诉我们,将局部的图结构和节点特征结合可以在节点分类任务中获得不错的表现。美中不足的是GCN结合邻近节点特征的方式和图的结构依依相关,这局限了训练所得模型在其他图结构上的泛化能力。

Graph Attention Network (GAT)提出了用注意力机制对邻近节点特征加权求和。邻近节点特征的权重完全取决于节点特征,独立于图结构。

在这个教程里我们将:

1、解释什么是Graph Attention Network
2、演示用DGL实现这一模型
3、深入理解学习所得的注意力权重
4、初探归纳学习(inductive learning)

难度:★★★★✩ (需要对图神经网络训练和Pytorch有基本了解)

在GCN里引入注意力机制

GAT和GCN的核心区别在于如何收集并累和距离为1的邻居节点的特征表示。在GCN里,一次图卷积操作包含对邻节点特征的标准化求和:

深入理解图注意力机制

其中  是对节点距离为1邻节点的集合。我们通常会加一条连接节点  和它自身的边使得  本身也被包括在里。 是一个基于图结构的标准化常数; 是一个激活函数 (GCN使用了ReLU); 是节点特征转换的权重矩阵,被所有节点共享。由于  和图的机构相关,使得在一张图上学习到的GCN模型比较难直接应用到另一张图上。解决这一问题的方法有很多,比如GraphSAGE提出了一种采用相同节点特征更新规则的模型,唯一的区别是他们将  设为了  。

图注意力模型GAT用注意力机制替代了图卷积中固定的标准化操作。以下图和公式定义了如何对第  层节点特征做更新得到第  层节点特征:

深入理解图注意力机制

注意力网络示意图和更新公式

对于上述公式的一些解释:

公式(1)对层节点嵌入做了线性变换,是该变换可训练的参数。
公式(2)计算了成对节点间的原始注意力分数。它首先拼接了两个节点的嵌入,注意在这里表示拼接;随后对拼接好的嵌入以及一个可学习的权重向量做点积;最后应用了一个LeakyReLU激活函数。这一形式的注意力机制通常被称为_加性注意力_,区别于Transformer里的点积注意力。
公式(3)对于一个节点所有入边得到的原始注意力分数应用了一个softmax操作,得到了注意力权重。
公式(4)形似GCN的节点特征更新规则,对所有邻节点的特征做了基于注意力的加权求和。

出于简洁的考量,在本教程中,我们选择省略了一些论文中的细节,如dropout, skip connection等等。感兴趣的读者们欢迎参阅文末链接的模型完整实现。本质上,GAT只是将原本的标准化常数替换为使用注意力权重的邻居节点特征聚合函数。

GAT的DGL实现

以下代码给读者提供了在DGL里实现一个GAT层的总体印象。别担心,我们会将以下代码拆分成三块,并逐块讲解每块代码是如何实现上面的一条公式。

实现公式(1)

深入理解图注意力机制

第一个公式相对比较简单。线性变换非常常见。在PyTorch里,我们可以通过torch.nn.Linear很方便地实现。

实现公式(2)

深入理解图注意力机制

原始注意力权重  是基于一对邻近节点  和  的表示计算得到。我们可以把注意力权重  看成在 i->j 这条边的数据。因此,在DGL里,我们可以使用 g.apply_edges 这一API来调用边上的操作,用一个边上的用户定义函数来指定具体操作的内容。我们在用户定义函数里实现了公式(2)的操作:

公式中的点积同样借由PyTorch的一个线性变换 attn_fc 实现。注意 apply_edges 会把所有边上的数据打包为一个张量,这使得拼接和点积可以并行完成。

实现公式(3)和(4)

深入理解图注意力机制

类似GCN,在DGL里我们使用update_all API来触发所有节点上的消息传递函数。update_all接收两个用户自定义函数作为参数。message_function发送了两种张量作为消息:消息原节点的表示以及每条边上的原始注意力权重。reduce_function随后进行了两项操作:

1、使用softmax归一化注意力权重 (公式(3))。
2、使用注意力权重聚合邻节点特征 (公式(4))。

这两项操作都先从节点的 mailbox 获取了数据,随后在数据的第二维( dim = 1 ) 上进行了运算。注意数据的第一维代表了节点的数量,第二维代表了每个节点收到消息的数量。

多头注意力 (Multi-head attention)

神似卷积神经网络里的多通道,GAT引入了多头注意力来丰富模型的能力和稳定训练的过程。每一个注意力的头都有它自己的参数。如何整合多个注意力机制的输出结果一般有两种方式:

  • 拼接:
  • 平均:

以上式子中是注意力头的数量。作者们建议对中间层使用拼接对最后一层使用求平均。

我们之前有定义单头注意力的GAT层,它可作为多头注意力GAT层的组建单元:

在Cora数据集上训练一个GAT模型

Cora是经典的文章引用网络数据集。Cora图上的每个节点是一篇文章,边代表文章和文章间的引用关系。每个节点的初始特征是文章的词袋(Bag of words)表示。其目标是根据引用关系预测文章的类别(比如机器学习还是遗传算法)。在这里,我们定义一个两层的GAT模型:

我们使用DGL自带的数据模块加载Cora数据集。

模型训练的流程和GCN教程里的一样。

可视化并理解学到的注意力

1、Cora数据集

以下表格总结了GAT论文以及dgl实现的模型在Cora数据集上的表现:

深入理解图注意力机制

可以看到DGL能完全复现原论文中的实验结果。对比图卷积网络GCN,GAT在Cora上有2~3个百分点的提升。

不过,我们的模型究竟学到了怎样的注意力机制呢?

由于注意力权重与图上的边密切相关,我们可以通过给边着色来可视化注意力权重。以下图片中我们选取了Cora的一个子图并且在图上画出了GAT模型最后一层的注意力权重。我们根据图上节点的标签对节点进行了着色,根据注意力权重的大小对边进行了着色(可参考图右侧的色条)。

深入理解图注意力机制

Cora数据集上学习到的注意力权重

乍看之下模型似乎学到了不同的注意力权重。为了对注意力机制有一个全局观念,我们衡量了注意力分布的熵。对于节点,  构成了一个在邻节点上的离散概率分布。它的熵被定义为:

深入理解图注意力机制

直观的说,熵低代表了概率高度集中,反之亦然。熵为则所有的注意力都被放在一个点上。均匀分布具有最高的熵(  )。在理想情况下,我们想要模型习得一个熵较低的分布(即某一、两个节点比其它节点重要的多)。注意由于节点的入度不同,它们注意力权重的分布所能达到的最大熵也会不同。

基于图中所有节点的熵,我们画了所有头注意力的直方图。

深入理解图注意力机制

Cora数据集上学到的注意力权重直方图

作为参考,下图是在所有节点的注意力权重都是均匀分布的情况下得到的直方图。

深入理解图注意力机制

出人意料的,模型学到的节点注意力权重非常接近均匀分布(换言之,所有的邻节点都获得了同等重视)。这在一定程度上解释了为什么在Cora上GAT的表现和GCN非常接近(在上面表格里我们可以看到两者的差距平均下来不到)。由于没有显著区分节点,注意力并没有那么重要。

这是否说明了注意力机制没什么用?不!在接下来的数据集上我们观察到了完全不同的现象。

2、蛋白质交互网络 (PPI)

PPI(蛋白质间相互作用)数据集包含了24张图,对应了不同的人体组织。节点最多可以有121种标签(比如蛋白质的一些性质、所处位置等)。因此节点标签被表示为有个121元素的二元张量。数据集的任务是预测节点标签。

我们使用了20张图进行训练,2张图进行验证,2张图进行测试。平均下来每张图有2372个节点。每个节点有50个特征,包含定位基因集合、特征基因集合以及免疫特征。至关重要的是,测试用图在训练过程中对模型完全不可见。这一设定被称为归纳学习。

我们比较了dgl实现的GAT和GCN在10次随机训练中的表现。模型的超参数在验证集上进行了优化。在实验中我们使用了micro f1 score来衡量模型的表现。

深入理解图注意力机制

在训练过程中,我们使用了 BCEWithLogitsLoss 作为损失函数。下图绘制了GAT和GCN的学习曲线;显然GAT的表现远优于GCN。

深入理解图注意力机制

PPI数据集上GCN和GAT学习曲线比较

像之前一样,我们可以通过绘制节点注意力分布之熵的直方图来有一个统计意义上的直观了解。以下我们基于一个3层GAT模型中不同模型层不同注意力头绘制了直方图。

第一层学到的注意力:

深入理解图注意力机制

第二层学到的注意力:

深入理解图注意力机制

最后一层学到的注意力:

深入理解图注意力机制

作为参考,下图是在所有节点的注意力权重都是均匀分布的情况下得到的直方图。

深入理解图注意力机制

可以很明显地看到,GAT在PPI上确实学到了一个尖锐的注意力权重分布。与此同时,GAT层与层之间的注意力也呈现出一个清晰的模式:在中间层随着层数的增加注意力权重变得愈发集中;最后的输出层由于我们对不同头结果做了平均,注意力分布再次趋近均匀分布。

不同于在Cora数据集上非常有限的收益,GAT在PPI数据集上较GCN和其它图模型的变种取得了明显的优势(根据原论文的结果在测试集上的表现提升了至少20%)。我们的实验揭示了GAT学到的注意力显著区别于均匀分布。虽然这值得进一步的深入研究,一个由此而生的假设是GAT的优势在于处理更复杂领域结构的能力。

拓展阅读

到目前为止我们演示了如何用DGL实现GAT。简介起见,我们忽略了dropout, skip connection等一些细节。这些细节很常见且独立于DGL相关的概念。有兴趣的读者欢迎参阅完整的代码实现。

1、经过优化的完整代码实现:https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py
2、在下一个教程中我们将介绍如何通过并行多头注意力和稀疏矩阵向量乘法来加速GAT模型,敬请期待!


<section style="white-space: normal;line-height: 1.75em;text-align: center;"><strong style="color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;widows: 1;background-color: rgb(255, 255, 255);font-size: 16px;max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;widows: 1;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;">长尾分布下图像分类问题最新综述(2019-2020)</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">LSTM终获正名,获IEEE 2021神经网络先驱奖!</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;">美国宣布“清洁网络”行动,限制7家中国科技公司<br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;">数据分析入门常用的23个牛逼Pandas代码</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="color: rgb(87, 107, 149);font-size: 14px;">如何在科研论文中画出漂亮的插图?</span><br  /></section></section></section></section></section></section></section></section></section>
深入理解图注意力机制

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享