知行编程网知行编程网  2022-05-01 16:00 知行编程网 隐藏边栏 |   抢沙发  160 
文章评分 0 次,平均分 0.0
Attention 扫盲:注意力机制及其 PyTorch 应用实现

来自 | 知乎

作者 | Lucas

地址 | https://zhuanlan.zhihu.com/p/88376673

专栏 | 深度学习与情感计算

编辑 | 机器学习算法与自然语言处理


Attention 扫盲:注意力机制及其 PyTorch 应用实现



仿生人脑注意力模型->计算资源分配

深度学习attention 机制是对人类视觉注意力机制的仿生,本质上是一种资源分配机制。生理原理就是人类视觉注意力能够以高分辨率接收于图片上的某个区域,并且以低分辨率感知其周边区域,并且视点能够随着时间而改变。换而言之,就是人眼通过快速扫描全局图像,找到需要关注的目标区域,然后对这个区域分配更多注意,目的在于获取更多细节信息和抑制其他无用信息。提高 representation 的高效性。例如,对于下面一张图,我的主要关注点就在于中间的 icon 和 ATTENTION 文字,对于边框上的条纹就不太关注,而且看一眼还有点晕。

Attention 扫盲:注意力机制及其 PyTorch 应用实现

Encoder-Decoder框架==sequence to sequence 条件生成框架

Encoder-Decoder框架,也被称为 sequence to sequence 条件生成框架[1],是一种文本处理领域的研究模式。常规的 encoder-decoder方法,第一步,将输入句子序列 X通过神经网络编码为固定长度的上下文向量C,也就是文本的语义表示;第二步,由另外一个神经网络作为解码器根据当前已经预测出来的词记忆编码后的上下文向量 C,来预测目标词序列,过程中编码器和解码器的 RNN 是联合训练的,但是监督信息只出现在解码器 RNN 一端,梯度随着反向传播到编码器 RNN 一端。使用 LSTM 进行文本建模时当前流行的有效方法[2]

attention 机制的最典型应用是统计机器翻译。给定任务,输入是“Echt”, “Dicke” and “Kiste”进 encoder,使用 rnn 表示文本为固定长度向量 h3。但问题就在于,当前 decoder 生成 y1 时仅仅依赖于最后一个隐层状态h3,也就是 sentence_embedding。那么这个 h3 必须 encode 输入句子中的全部信息才行。可实际上,传统Encoder-Decoder模型并不能达到这个功能。那 LSTM [3]不就是用来解决长期依赖信息问题的嘛?但事实上,长短期记忆网络仍然存在问题。我们说,RNN在长期信息访问当前处理单元之前,需要按顺序地通过所有之前的单元。这意味着它很容易遭遇梯度消失问题。然后引入 LSTM,使用门控某种程度上解决这个问题。的确,LSTM、GRU 和其变体能学习大量的长期信息,但它们最多只能记住相对长的信息,而不是更大更长。

Attention 扫盲:注意力机制及其 PyTorch 应用实现
使用 RNN 文本表示与生成

所以,我们来总结一下传统 encoder-decoder的一般范式及其问题:任务是翻译中文“我/爱/赛尔”到英文。传统 encoder-decoder 先把整句话输入进去,编码最后一个词“赛尔”结束之后,使用 RNN生成一个整句话的表示-向量 C,在条件生成时,当翻译到第 2个词“赛尔”的时候,需要退 1 步找到已经预测出来的h_1以及上下文表示 C, 然后 decode 输出。

从注意力均等到注意力集中

在传统Encoder-Decoder 框架下:由解码器根据当前已经预测出来的词记忆编码后的上下文向量 C,来预测目标词序列。也就是说,不论生成那个词,我们使用的句子编码表示 C 都是一样的。换句话说,句子中任意单词对生成某个目标单词P_yi来说影响力都是相同的,也就是注意力均等。很显然这不符合直觉直觉应该:我翻译哪个部分,哪个部分就应该把注意力集中于我的翻译的原文,翻译到第一个词,就应该多关注原文中的第一个词是什么意思。详见伪代码和下图:

P_y3 = F((E<black>,C)
Attention 扫盲:注意力机制及其 PyTorch 应用实现
传统 Encoder-Decoder 框架下的 RNN 进行文本翻译,一直使用同一个 c

接下来观察上下两个图的区别:相同的上下文表示C会替换成根据当前生成单词而不断变化的Ci。

Attention 扫盲:注意力机制及其 PyTorch 应用实现
融合 attention 机制的RNN 模型进行文本翻译每个时刻生成不同的 c

文本翻译过程变为:

P_y3 = F((E<black>,C_2)

Encoder-Decoder框架的代码实现[4]

        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

考虑可解释性

不含注意力模型的传统encoder-decoder 可解释差:对于编码向量中究竟编码了什么信息,如何利用这些信息以及解码器特定行为的原因是什么我们并没有明确的认识。包含注意力机制的结构提供了一张相对简单的方式让我们了解解码器的推理过程以及模型究竟在学习什么内容,学到那些东西。尽管是一种弱可解释性,但是已经 make sense 了。

直面 attention 的核心公式

在预测目标语言的第i个词时,源语言第j个词的权重为  , 权重的大小可i以j 看做是一种源语言和目标语言的软对齐信息。 

总结

使用 attention 方法实际上就在于预测一个目标词 yi 时,自动获取原句中不同位置的语义信息,并给每个位置信息的语义赋予的一个权重,也就是“软”对齐信息,将这些信息整理起来计算对于当前词 yi 的原句向量表示 c_i。

Attention 的 PyTorch应用实现

        return self.out(attn_output), attention # model : [batch_size, num_classes], attention : [batch_size, n_step]


github地址:

https://github.com/zy1996code/nlp_basic_model/blob/master/lstm_attention.py

参考

  1. ^《Neural Network Methods in Natural Language Processing》

  2. ^Sequence to Sequence Learning with Neural Networks https://arxiv.org/pdf/1409.3215.pdf

  3. ^LSTM 扫盲:长短期记忆网络解读及其 PyTorch 实现 https://zhuanlan.zhihu.com/p/86876988

  4. ^The Annotated Transformer https://nlp.seas.harvard.edu/2018/04/03/attention.html


<section style="margin-right: 8px;margin-left: 8px;white-space: normal;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;text-align: center;widows: 1;line-height: 1.75em;"><strong><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong>完<strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong></span></strong></span></strong></section><section style="white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;text-align: center;widows: 1;color: rgb(255, 97, 149);"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="letter-spacing: 0.544px;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="margin-right: 8px;margin-bottom: 15px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 25.5938px;letter-spacing: 3px;"><span style="color: rgb(0, 0, 0);"><strong><span style="font-size: 16px;font-family: 微软雅黑;caret-color: red;">为您推荐</span></strong></span></section><p style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">神经网络激励函数的作用是什么?有没有形象的解释?<br  /></p><p style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">“12306”的架构到底有多牛逼?<br  /></p><p style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">网传饶毅举报多位学者论文造假?官方回应了<br  /></p><p style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">阿里如何抗住90秒100亿?看这篇你就明白了!<br  /></p><p style="margin-right: 8px;margin-bottom: 5px;margin-left: 8px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-family: sans-serif;font-size: 12px;line-height: 1.75em;letter-spacing: 0px;">深度学习必懂的13种概率分布</p></section></section></section></section></section></section></section></section>

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享