循环神经网络模型详解:从RNN到LSTM
循环神经网络模型详解:从RNN到LSTM
进入这篇文章之前,想必大家已经阅读过前面的系列文章:
- 【通俗易懂说模型】线性回归(附深度学习、机器学习发展史)
- 【通俗易懂说模型】非线性回归和逻辑回归(附神经网络图详细解释)
- 【通俗易懂说模型】反向传播(附多元分类与Softmax函数)
- 【通俗易懂说模型】卷积神经网络(呕心沥血版)
- 【通俗易懂说模型】一篇弄懂几个经典CNN图像模型(AlexNet、VGGNet、ResNet)
通过前面的文章,相信读者对深度学习、机器学习已经有了一个较为全面且细致的理解。接下来,本文将基于前面提到的回归、反向传播、卷积神经网络等知识,从深度学习在图像识别领域发展的历史脉络出发,带你走入序列转序列模型。从图像识别中的卷积神经网络转入序列模型中的循环神经网络模型模型,学习人工智能中的奇思妙想,感悟前辈伟人的思想精华。
2. 循环神经网络模型
前面我们学习过卷积神经网络模型(CNN),就是如下这个模型,CNN模型是基于卷积操作完成对图像得特征提取。(总结一下,前面我们已经学过两种模型:1、连接模型/全连接等,用于回归、分类操作;2、卷积神经网络模型,用于图像特征提取)
序列转序列模型的就是记忆+理解。如果我们能把语言记忆下来(语言前后能够联系起来),同时理解语言的句子,那么我们就能够完成翻译或者其他操作。之前学习的卷积神经网络受到了生物视觉细胞的启发,相似地,循环神经网络受到了生物记忆能力的启发。循环神经网络是具有循环结构的一类神经网络,我们又称之为RNN,此外还有RNN的加强版LSTM和GRU,它们都拥有更强的“记忆力”。接下来,我们重点对RNN进行讲解。
循环神经网络(RNN):针对 序列转序列模型
语言理解的核心:记忆
2.1 RNN模型
2.1.1 RNN出现之因
要理解RNN出现的原因,我们必须要理解CNN存在的缺点。细细研究上图,我们会发现, 他们的输出不会受之前输出的影响,仅仅受输入特征值的影响,即隐藏层之间没有连接(每一个隐藏层的块表示每一个时刻的输出值)。总之,CNN考虑不到时间维度上的影响,仅仅能考虑一个时间点,事物不同特征值的输入。那么,对于简单的猫,狗,手写数字等单个物体的识别具有较好的效果.。但是, 对于一些与时间先后有关的, 比如视频的下一时刻的预测,文档前后文内容的预测等, 这些算法的表现就不尽如人意了。因此, RNN就应运而生了。RNN每个时间点的输出受到之前所有时间点输出的影响,然后综合考虑这些信息做出下一个输出
2.1.2 什么是RNN
RNN是一种特殊的神经网络结构, 它是根据"人的认知是基于过往的经验和记忆"这一观点提出的. 它与DNN,CNN不同的是: 它不仅考虑前一时刻的输入,而且赋予了网络对前面的内容的一种'记忆'功能。
结合现实来看,我们每一个人的性格特点都是由以往所有的经历所造成的。离现在时间点越久远的经历对现在的我们影响越小,而发生在最近的经历对我们的影响越大。
RNN之所以称为循环神经网路,即一个序列当前的输出与前面的输出也有关。具体的表现形式为网络会对前面的信息进行记忆并应用于当前输出的计算中,即隐藏层之间的节点不再无连接而是有连接的,并且隐藏层的输入不仅包括输入层的输出还包括上一时刻隐藏层的输出,即隐藏层之间还存在先后的时间序关系
如下图所示,普通神经网络的数据是单向传递的,而循环神经网络的数据是循环传递的,输入层x经过隐含层之后输出y,而隐含层输出的结果h需要作为下一次输入的一部分,循环传递。
循环神经网络可以展开成一连串相互连接的前向网络,如下图中的等式右侧所示。假如我们要输入的序列是(xo,x1,...,xt),x输入隐含层后输出结果y0和隐藏向量h0,接着将h0当作第二次输入的一部分与x一起输入隐含层,得到输出结果y1和隐藏向量h1,以此类推。这么做的目的是将前面时刻的输入信息通过隐藏向量传递到后面时刻,这样网络就有了一定的“记忆力”。隐藏向量不断循环传递信息,所以被称为循环神经网络。
例如,我们可以将循环神经网络运用到句子预测上面。我们输入“我会讲中”来让计算机自动预
测出“文”字。如下图所示,“我”“会”“讲”“中”被当作一个输入序列依次输入循环神经网络,“我”“会”“讲”3个字的历史信息通过隐藏向量h0、h1和h2依次传递到最后,由h2和“中”作为最后的输入值,并输出预测的“文”字。
下面我们利用PyTorch的nn.RNN模块去感性认识一下循环神经网络。我们先通过以下代码初始化一个RNN单元,该单元的输入特征维度为5,隐藏向量的特征维度为7,结构如下图所示:
from torch import nn
rnn_cell = nn.RNNCell(input_size=5, hidden_size=7)
rnn_cell
为了更好地理解,下面给出RNN的数学表达:
其中,ht-1为前一时刻输出的隐藏向量,xt为该时刻的输入向量。我们从公式中可以看出,RNN的数学本质是将xt和ht-1分别进行线性变换并相加,随后经过非线性层,如tanh函数,得到该时刻的隐藏向量ht。(这里隐藏向量就是暂时的输出向量)
我们可以通过RNNCell的权值属性weight_ih、weight_hh、bias_ih和bias_hh来访问公式中的Win、Whh、bin和bh值:
RNNCell需要两个输入:一是输入向量xt,它是格式为(batch,input_size)的Tensor,我们假设批量数为1,输人向量的维度为5,随机初始化一个输入向量;二是上一个时刻的隐藏向量ht-1
其格式是(batch,hidden_size),我们假设批量数为1,隐藏向量维度为7,随机初始化一个隐藏向量,相关代码如下:
我们可以从上述代码中看到输出的隐藏向量ht。RNNCell在处理一个序列的输入向量时,必须采用循环的方式逐个向量进行输入。而PyTorch为我们定义了一个RNN模块,可以直接将序列当成输入。我们先初始化一个RNN:
import torch
from torch import nn
rnn = nn.RNN(input_size=5, hidden_size=7)
input = torch.randn(3,2,5)
hidden = torch.randn(1,2,7)
rnn(input,hidden)
RNN也需要两个输入:一是输入向量序列x,默认情况下,它是格式(seq,batch,input_size)的Tensor;二是上一个时刻的隐藏向量ht-1,其格式为(layers*direction,batch,hidden_size)。layers表示RNN的隐藏节点的层数,direction表示RNN的方向,默认情况下两者均为1。下面我们初始化一个批量数为2、序列长度为3、特征维度数为5的输入以及一个批量数为2、维度为7的隐藏向量:
从上面的代码可以看出,RNN和RNNCeII的不同在于它可以同时处理一串序列,并且同时返回输出向量序列和隐藏向量。在实际使用的过程中,可以根据需要和习惯选择其中一种。
2.1.3 RNN存在的问题
RNN存在长期依赖问题:长期依赖问题是指RNN难以捕捉长期之前的依赖关系。换句话说,很久之前的数据在RNN中就会被彻底遗忘。并且这个长期实际上并不长,差不多到50个词语RNN就会彻底遗忘前面的数据,这就导致RNN很难处理长文本
举个例子:
1、如果从“这块冰糖味道真?”来预测下一个词,RNN是很容易得出“甜”结果的
2、如果有这么一句话,“他吃了一口菜,被辣的流出了眼泪,满脸通红。旁边的人赶紧给他倒了一杯凉水,他咕咚咕咚喝了两口,才逐渐恢复正常。他气愤地说道:这个菜味道真?”对于这句话,利用RNN来处理时,RNN在语句最后时早就忘了开头的信息“辣的流出来眼泪”,所以它难以成功预测出结果
因此,RNN难以处理具有长期依赖关系的问题
RNN难以处理长期依赖问题的根本原因在于:经过多层多阶段传播后存在梯度消失(大部分,模型仍可用但是长期依赖会遗忘)或梯度爆炸(很少,但是一遇到模型就彻底完蛋)问题
2.1.4 RNN梯度消失、爆炸的原因
假设时间序列只有三段,初始进位量为S0,则RNN的前向传播过程为:
再假设损失函数为:
训练RNN模型的本质就是求解Wx、Ws和Wo的值。进行反向传播求解模型:求解Lt取最小时,Wx、Ws和Wo的值。于是,对Lt求这三个值的偏导有:
可以看到:
1、Lt对W0没有长期依赖问题
2、但是对Wx、Ws求偏导时,偏导值不仅与当前的输入Xt有关,也与X0-Xt-1有关。并且这个关系是偏导间积的关系
再将这个序列的段数扩大到n段(不再是三段),那么此时求偏导后值为:
Sj是每个时间步结合当前输入以及前时刻输入得到的信息总和,为提高信息的非线性成分以及对不必要信息进行剔除以降低梯度爆炸或消失,我们要在每一步信息传递给下一步时间块前,利用激活函数对信息进行激活处理,如下:
那么此时(划重点):
1、如果激活函数tanh的导数小于1,那么随着累乘的增加,RNN会出现梯度消失的情况;如果激活函数tanh的导数大于1,那么随着累乘的增加,RNN会出现梯度爆炸的情况
2.2 LSTM模型
LSTM是Long-ShortTermMemory的缩写,中文名叫长短期记忆网络,可以将它看作RNN的改进版本。传统的RNN模型在处理长序列时常常出现“梯度消失”的问题,为了让网络能够更好地“记住”以前的信息,Hochreiter和Schmidhuber提出了LSTM模型,经过改良的LSTM在很多方面取得了相当巨大的成功。LSTM的结构如下图所示:
如上图所示,从表面上看,LSTM与RNN模块的不同在于LSTM的输入有3个:ht-1、Ct-1和xt,输出为(ht,ct)。我们先不研究LSTM的内部结构,看一下这段代码:
lstm_cell = nn.LSTMCell(input_size=5,hidden_size=7)
lstm_cell
input = torch.randn(1,5)
h0 = torch.randn(1,7)
c0 = torch.randn(1,7)
h1,c1 = lstm_cell(input,(h0,c0))
print(h1)
print(c1)
同样,PyTorch为我们定义了一个LSTM模块,可以直接将序列当成输入。我们先初始化一个LSTM:
lstm = nn.LSTM(input_size=5,hidden_size=7)
input = torch.randn(3,2,5)
h0 = torch.randn(1,2,7)
c0 = torch.randn(1,2,7)
output,(h1,c1) = lstm(input, (h0,c0))
print(output.size())
print(h1.size())
print(c1.size())
3. 总结
如果想学习更多深度学习文章,可以订阅一下热门专栏:
- 深度学习
- PyTorch实战深度学习80例
- 零基础入门PyTorch框架
如果想要学习更多pyTorch/python编程的知识,大家可以点个关注并订阅,持续学习、天天进步