知行编程网知行编程网  2022-07-05 12:00 知行编程网 隐藏边栏 |   抢沙发  2 
文章评分 0 次,平均分 0.0

如何从NumPy直接创建RNN?

木易 发自 凹非寺 
来自 | 量子位

使用成熟的Tensorflow、PyTorch框架去实现递归神经网络(RNN),已经极大降低了技术的使用门槛。

但是,对于初学者,这还是远远不够的。知其然,更需知其所以然。

如何从NumPy直接创建RNN?

要避免低级错误,打好理论基础,然后使用RNN去解决更多实际的问题的话。

那么,有一个有趣的问题可以思考一下:

不使用Tensorflow等框架,只有Numpy的话,你该如何构建RNN?

没有头绪也不用担心。这里便有一项教程:使用Numpy从头构建用于NLP领域的RNN。

可以带你行进一遍RNN的构建流程。

   初始化参数

与传统的神经网络不同,RNN具有3个权重参数,即:

输入权重(input weights),内部状态权重(internal state weights)和输出权重(output weights)

首先用随机数值初始化上述三个参数。

之后,将词嵌入维度(word_embedding dimension)和输出维度(output dimension)分别初始化为100和80。

输出维度是词汇表中存在的唯一词向量的总数。

, (output_dim,hidden_dim))

变量prev_memory指的是internal_state(这些是先前序列的内存)。

其他参数也给予了初始化数值。

input_weight梯度,internal_state_weight梯度和output_weight梯度分别命名为dU,dW和dV。

变量bptt_truncate表示网络在反向传播时必须回溯的时间戳数,这样做是为了克服梯度消失的问题。


   前向传播

输出和输入向量

例如有一句话为:I like to play.,则假设在词汇表中:

I被映射到索引2,like对应索引45,to对应索引10、**对应索引64而标点符号.** 对应索引1。

为了展示从输入到输出的情况,我们先随机初始化每个单词的词嵌入。

输入已经完成,接下来需要考虑输出。

在本项目中,RNN单元接受输入后,输出的是下一个最可能出现的单词。

用于训练RNN,在给定第t+1个词作为输出的时候将第t个词作为输入,例如:在RNN单元输出字为“like”的时候给定的输入字为“I”.

现在输入是嵌入向量的形式,而计算损失函数(Loss)所需的输出格式是独热编码(One-Hot)矢量。

这是对输入字符串中除第一个单词以外的每个单词进行的操作,因为该神经网络学习只学习的是一个示例句子,而初始输入是该句子的第一个单词。

RNN的黑箱计算

现在有了权重参数,也知道输入和输出,于是可以开始前向传播的计算。

训练神经网络需要以下计算:

如何从NumPy直接创建RNN?

其中:

U代表输入权重、W代表内部状态权重,V代表输出权重。

输入权重乘以input(x),内部状态权重乘以前一层的激活(prev_memory)。

层与层之间使用的激活函数用的是tanh。

计算损失函数

之后损失函数使用的是交叉熵损失函数,由下式给出:

如何从NumPy直接创建RNN?


最重要的是,我们需要在上面的代码中看到第5行。

正如所知,ground_truth output(y)的形式是[0,0,….,1,…0]和predicted_output(y^hat)是[0.34,0.03,……,0.45]的形式,我们需要损失是单个值来从它推断总损失。

为此,使用sum函数来获得特定时间戳下y和y^hat向量中每个值的误差之和。

total_loss是整个模型(包括所有时间戳)的损失。

   反向传播

反向传播的链式法则:

如何从NumPy直接创建RNN?

如上图所示:

Cost代表误差,它表示的是y^hat到y的差值。

由于Cost是的函数输出,因此激活a所反映的变化由dCost/da表示。

实际上,这意味着从激活节点的角度来看这个变化(误差)值。

类似地,a相对于z的变化表示为da/dz,z相对于w的变化表示为dw/dz。

最终,我们关心的是权重的变化(误差)有多大。

如何从NumPy直接创建RNN?

而由于权重与Cost之间没有直接关系,因此期间各个相对的变化值可以直接相乘(如上式所示)。

RNN的反向传播

由于RNN中存在三个权重,因此我们需要三个梯度。input_weights(dLoss / dU),internal_state_weights(dLoss / dW)和output_weights(dLoss / dV)的梯度。

这三个梯度的链可以表示如下:

如何从NumPy直接创建RNN?

所述dLoss/dy_unactivated代码如下:

计算两个梯度函数,一个是multiplication_backward,另一个是additional_backward。

在multiplication_backward的情况下,返回2个参数,一个是相对于权重的梯度(dLoss / dV),另一个是链梯度(chain gradient),该链梯度将成为计算另一个权重梯度的链的一部分。

在addition_backward的情况下,在计算导数时,加法函数(ht_unactivated)中各个组件的导数为1。例如:dh_unactivated / dU_frd=1(h_unactivated = U_frd + W_frd),且dU_frd / dU_frd的导数为1。

所以,计算梯度只需要这两个函数。multiplication_backward函数用于包含向量点积的方程,addition_backward用于包含两个向量相加的方程。

如何从NumPy直接创建RNN?


至此,已经分析并理解了RNN的反向传播,目前它是在单个时间戳上实现它的功能,之后可以将其用于计算所有时间戳上的梯度。

如下面的代码所示,forward_params_t是一个列表,其中包含特定时间步长的网络的前向参数。

变量ds是至关重要的部分,因为此代码考虑了先前时间戳的隐藏状态,这将有助于提取在反向传播时所需的信息。

对于RNN,由于存在梯度消失的问题,所以采用的是截断的反向传播,而不是使用原始的。

在此技术中,当前单元将只查看k个时间戳,而不是只看一次时间戳,其中k表示要回溯的先前单元的数量。


   权重更新

一旦使用反向传播计算了梯度,则更新权重势在必行,而这些是通过批量梯度下降法


   训练序列

完成了上述所有步骤,就可以开始训练神经网络了。

用于训练的学习率是静态的,还可以使用逐步衰减等更改学习率的动态方法。

)

恭喜你!你现在已经实现从头建立递归神经网络了!

那么,是时候了,继续向LSTM和GRU等的高级架构前进吧。

原文链接:
https://medium.com/@rndholakia/implementing-recurrent-neural-network-using-numpy-c359a0a68a67

<section data-brushtype="text" style="padding-right: 0em;padding-left: 0em;white-space: normal;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);font-family: "Helvetica Neue", Helvetica, "Hiragino Sans GB", "Microsoft YaHei", Arial, sans-serif;widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;color: rgb(0, 0, 0);font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong>完<strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;font-size: 14px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;font-size: 16px;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;letter-spacing: 0.5px;box-sizing: border-box !important;overflow-wrap: break-word !important;">—</span></strong></span></strong></span></strong></section><pre style="padding-right: 0em;padding-left: 0em;max-width: 100%;letter-spacing: 0.544px;color: rgb(62, 62, 62);widows: 1;word-spacing: 2px;caret-color: rgb(255, 0, 0);text-align: center;box-sizing: border-box !important;overflow-wrap: break-word !important;"><pre style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;white-space: normal;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;letter-spacing: 0.544px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section powered-by="xiumi.us" style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-top: 15px;margin-bottom: 25px;max-width: 100%;opacity: 0.8;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><section style="margin-bottom: 15px;padding-right: 0em;padding-left: 0em;max-width: 100%;color: rgb(127, 127, 127);font-size: 12px;font-family: sans-serif;line-height: 25.5938px;letter-spacing: 3px;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;"><strong style="max-width: 100%;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="max-width: 100%;font-size: 16px;font-family: 微软雅黑;caret-color: red;box-sizing: border-box !important;overflow-wrap: break-word !important;">为您推荐</span></strong></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">干货 | 算法工程师超实用技术路线图</section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">13个算法工程师必须掌握的PyTorch Tricks</span><span style="letter-spacing: 0.544px;font-family: -apple-system-font, system-ui, "Helvetica Neue", "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei UI", "Microsoft YaHei", Arial, sans-serif;"></span></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;"><span style="font-size: 14px;">吴恩达上新:生成对抗网络(GAN)专项课程</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;box-sizing: border-box !important;overflow-wrap: break-word !important;">拿到2021灰飞烟灭算法岗offer的大佬们是啥样的<span style="font-size: 14px;">?</span><br  /></section><section style="margin-top: 5px;margin-bottom: 5px;padding-right: 0em;padding-left: 0em;max-width: 100%;min-height: 1em;font-family: sans-serif;letter-spacing: 0px;opacity: 0.8;line-height: normal;color: rgb(0, 0, 0);box-sizing: border-box !important;overflow-wrap: break-word !important;">你一定从未看过如此通俗易懂的YOLO系列解读 (下)</section></section></section></section></section></section></section></section></section>

如何从NumPy直接创建RNN?

本篇文章来源于: 深度学习这件小事

本文为原创文章,版权归所有,欢迎分享本文,转载请保留出处!

知行编程网
知行编程网 关注:1    粉丝:1
这个人很懒,什么都没写

你可能也喜欢

热评文章

发表评论

表情 格式 链接 私密 签到
扫一扫二维码分享