知行编程网知行编程网  2022-07-13 12:00 知行编程网 隐藏边栏 |   抢沙发  187 
文章评分 0 次,平均分 0.0

LSTM模型结构的可视化

来自 | 知乎  作者 | master苏

链接 | https://zhuanlan.zhihu.com/p/139617364

编辑 | 深度学习这件小事

本文仅作学术交流,如有侵权,请联系后台删除。
最近在学习LSTM应用在时间序列的预测上,但是遇到一个很大的问题就是LSTM在传统BP网络上加上时间步后,其结构就很难理解了,同时其输入输出数据格式也很难理解,网络上有很多介绍LSTM结构的文章,但是都不直观,对初学者是非常不友好的。我也是苦苦冥思很久,看了很多资料和网友分享的LSTM结构图形才明白其中的玄机。

本文内容如下:

   传统的BP网络和CNN网络
BP网络和CNN网络没有时间维,和传统的机器学习算法理解起来相差无几,CNN在处理彩色图像的3通道时,也可以理解为叠加多层,图形的三维矩阵当做空间的切片即可理解,写代码的时候照着图形一层层叠加即可。如下图是一个普通的BP网络和CNN网络。
LSTM模型结构的可视化

BP网络
LSTM模型结构的可视化
CNN网络
图中的隐含层、卷积层、池化层、全连接层等,都是实际存在的,一层层前后叠加,在空间上很好理解,因此在写代码的时候,基本就是看图写代码,比如用keras就是:

   LSTM网络
当我们在网络上搜索看LSTM结构的时候,看最多的是下面这张图:
LSTM模型结构的可视化
RNN网络
这是RNN循环神经网络经典的结构图,LSTM只是对隐含层节点A做了改进,整体结构不变,因此本文讨论的也是这个结构的可视化问题。
中间的A节点隐含层,左边是表示只有一层隐含层的LSTM网络,所谓LSTM循环神经网络就是在时间轴上的循环利用,在时间轴上展开后得到右图。
看左图,很多同学以为LSTM是单输入、单输出,只有一个隐含神经元的网络结构,看右图,以为LSTM是多输入、多输出,有多个隐含神经元的网络结构,A的数量就是隐含层节点数量。
WTH?思维转不过来啊。这就是传统网络和空间结构的思维。
实际上,右图中,我们看Xt表示序列,下标t是时间轴,所以,A的数量表示的是时间轴的长度,是同一个神经元在不同时刻的状态(Ht),不是隐含层神经元个数。
我们知道,LSTM网络在训练时会使用上一时刻的信息,加上本次时刻的输入信息来共同训练。
举个简单的例子:在第一天我生病了(初始状态H0),然后吃药(利用输入信息X1训练网络),第二天好转但是没有完全好(H1),再吃药(X2),病情得到好转(H2),如此循环往复知道病情好转。因此,输入Xt是吃药,时间轴T是吃多天的药,隐含层状态是病情状况。因此我还是我,只是不同状态的我。
实际上,LSTM的网络是这样的:
LSTM模型结构的可视化

LSTM网络结构
上面的图表示包含2个隐含层的LSTM网络,在T=1时刻看,它是一个普通的BP网络,在T=2时刻看也是一个普通的BP网络,只是沿时间轴展开后,T=1训练的隐含层信息H,C会被传递到下一个时刻T=2,如下图所示。上图中向右的五个常常的箭头,所的也是隐含层状态在时间轴上的传递。
LSTM模型结构的可视化
注意,图中H表示隐藏层状态,C是遗忘门,后面会讲解它们的维度。

   LSTM的输入结构

为了更好理解LSTM结构,还必须理解LSTM的数据输入情况。仿照3通道图像的样子,在加上时间轴后的多样本的多特征的不同时刻的数据立方体如下图所示:
LSTM模型结构的可视化

三维数据立方体
右边的图是我们常见模型的输入,比如XGBOOST,lightGBM,决策树等模型,输入的数据格式都是这种(N*F)的矩阵,而左边是加上时间轴后的数据立方体,也就是时间轴上的切片,它的维度是(N*T*F),第一维度是样本数,第二维度是时间,第三维度是特征数,如下图所示:
LSTM模型结构的可视化
这样的数据立方体很多,比如天气预报数据,把样本理解成城市,时间轴是日期,特征是天气相关的降雨风速PM2.5等,这个数据立方体就很好理解了。在NLP里面,一句话会被embedding成一个矩阵,词与词的顺序是时间轴T,索引多个句子的embedding三维矩阵如下图所示:

LSTM模型结构的可视化


   pytorch中的LSTM

4.1 pytorch中定义的LSTM模型

pytorch中定义的LSTM模型的参数如下
结合前面的图形,我们一个个看。
(1)input_size:x的特征维度,就是数据立方体中的F,在NLP中就是一个词被embedding后的向量长度,如下图所示:
LSTM模型结构的可视化
(2)hidden_size:隐藏层的特征维度(隐藏层神经元个数),如下图所示,我们有两个隐含层,每个隐藏层的特征维度都是5。注意,非双向LSTM的输出维度等于隐藏层的特征维度。
LSTM模型结构的可视化
(3)num_layers:lstm隐层的层数,上面的图我们定义了2个隐藏层。
(4)batch_first:用于定义输入输出维度,后面再讲。
(5)bidirectional:是否是双向循环神经网络,如下图是一个双向循环神经网络,因此在使用双向LSTM的时候我需要特别注意,正向传播的时候有(Ht, Ct),反向传播也有(Ht', Ct'),前面我们说了非双向LSTM的输出维度等于隐藏层的特征维度,而双向LSTM的输出维度是隐含层特征数*2,而且H,C的维度是时间轴长度*2。
LSTM模型结构的可视化


4.2 喂给LSTM的数据格式

pytorch中LSTM的输入数据格式默认如下:
前面也说到,如果LSTM的参数 batch_first=True,则要求输入的格式是:
刚好调换前面两个参数的位置。其实这是比较好理解的数据形式,下面以NLP中的embedding向量说明如何构造LSTM的输入。
之前我们的embedding矩阵如下图:
LSTM模型结构的可视化
如果把batch放在第一位,则三维矩阵的形式如下:
LSTM模型结构的可视化
其转换过程如下图所示:
LSTM模型结构的可视化
看懂了吗,这就是输入数据的格式,是不是很简单。
LSTM的另外两个输入是 h0 和 c0,可以理解成网络的初始化参数,用随机数生成即可。
注意,如果我们定义的input格式是:
则H和C的格式也是要变的:


4.3 LSTM的output格式

LSTM的输出是一个tuple,如下:
output的默认维度是:
和input的情况类似,如果我们前面定义的input格式是:
则ht和ct的格式也是要变的:
说了这么多,我们回过头来看看ht和ct在哪里,请看下图:
LSTM模型结构的可视化
output在哪里?请看下图:
LSTM模型结构的可视化

   LSTM和其他网络组合
还记得吗,output的维度等于隐藏层神经元的个数,即hidden_size,在一些时间序列的预测中,会在output后,接上一个全连接层,全连接层的输入维度等于LSTM的hidden_size,之后的网络处理就和BP网络相同了,如下图:
LSTM模型结构的可视化
用pytorch实现上面的结构:
当然,有些模型则是将输出当做另一个LSTM的输入,或者使用隐藏层ht,ct的信息进行建模,不一而足。
好了,以上就是我对LSTM的一些学习心得,看完记得关注点赞。

参考链接:
https://zhuanlan.zhihu.com/p/94757947
https://zhuanlan.zhihu.com/p/59862381
https://zhuanlan.zhihu.com/p/36455374
https://www.zhihu.com/question/41949741/answer/318771336
https://blog.csdn.net/android_ruben/article/details/80206792

<section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;"><strong style="color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong>完<strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;font-size: 14px;"><strong style="font-size: 16px;letter-spacing: 0.544px;"><span style="letter-spacing: 0.5px;">—</span></strong></span></strong></span></strong></section><pre><pre><section style="letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="letter-spacing: 0.544px;"><section powered-by="xiumi.us"><section style="margin-top: 15px;margin-bottom: 25px;opacity: 0.8;"><section><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;text-align: center;"><span style="color: rgb(0, 0, 0);"><strong><span style="font-size: 16px;font-family: 微软雅黑;caret-color: red;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">一文了解深度推荐算法的演进</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;"><span style="font-size: 14px;">13个算法工程师必须掌握的PyTorch Tricks</span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;text-align: center;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的<span style="font-size: 14px;">?</span></section></section></section></section></section></section></section></section></section>

LSTM模型结构的可视化

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

你可能也喜欢

热评文章

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享