LSTM-循环神经网络详解
LSTM-循环神经网络详解
LSTM(长短时记忆网络)是应用最广泛的循环神经网络之一,特别适用于处理具有长期依赖性的问题。本文将详细介绍LSTM的前向传播和反向传播过程,并通过一个具体的Python代码示例,展示如何使用LSTM拟合一个函数。
一、LSTM前向传播
LSTM的模型结构类似于数字电路,按时间维度展开后的模型如下图所示:
LSTM引入了“门”的概念,通过“与门”、“或门”、“异或门”等逻辑门的组合,可以实现复杂的计算功能。在LSTM中,这些“门”是通过软件实现的,类似于《三体》中牛顿和冯·诺依曼利用3000万士兵组成的人肉电脑。
LSTM中的“遗忘门”(用符号ft表示)是一个简单的全连接神经网络,其计算公式如下:
其中,σ代表Sigmoid函数,ht-1是上一个序列的输出,xt为本次输入。遗忘门的作用是决定哪些信息需要从细胞状态中遗忘。
除了遗忘门,LSTM还包括“输入门”it和“候选门”:
有了这些门之后,就可以引入LSTM的核心概念——细胞状态Ct:
每个时刻的细胞状态Ct分为两部分:一部分是有选择地保留了上一次细胞状态值Ct-1,另一部分来自本次的输入。更新完Ct后,通过输出门ot得到此时的输出ht。
LSTM的这种设计类似于中国古代的“万年历”,通过记录历史信息并结合当前输入,可以预测未来的发展趋势。在《三体》中,墨子通过长期观察和记录三个太阳的运动轨迹,总结出一套“万年历”,可以预测“恒纪元”与“乱纪元”的更迭。
具体来看LSTM的前向传播过程:
- 输出ht代表需要预测的信息,由输出门和Ct运算得到
- Ct包含历史信息Ct-1和本次输入
- 遗忘门ft决定如何取舍上一次的细胞状态信息
- 输入门it选择性提取输入信息的主要特征
以一个具体的例子说明LSTM的运行过程:
import numpy as np
import math
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
modelname = 'model/lstmfunction'
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
datainterval=10
def load(num):
x=range(num)
y=[float(3*math.sin(t)+ math.cos(5*t)+np.random.randn(1,1)+6) for t in x]
return np.array(y,dtype=np.float32)
class lstm(nn.Module):
def __init__(self, input_size , hidden_size , output_size , num_layer,bidirect =1,dropout=0.1 ):
super(lstm, self).__init__()
self.hiddensize= hidden_size
self.bidirectional=bidirect
self.layer1 = nn.LSTM(input_size, hidden_size, num_layer ,bidirectional=False if bidirect==1 else True )
self.Dropped = nn.Dropout(dropout)
self.Tanh = nn.Tanh()
self.layer2 = nn.Linear(hidden_size*bidirect, output_size)
self.BN=nn.BatchNorm1d(datainterval)
def forward(self, x):
input=self.BN(x)
input = input.view(-1, datainterval, 1)
x = torch.transpose(input, 0, 1)
out,(hidden,_) = self.layer1(x)
out = out[-1, :, :]
out=out.view(-1, self.hiddensize*self.bidirectional)
out= self.Dropped (out)
out = self.layer2(out)
return out
def create_dataset(num,trainfactor, interval ):
dataset = load(num)
dataX, dataY = [], []
for i in range(len(dataset) - interval):
a = dataset[i:(i + interval)]
dataX.append(a)
dataY.append(dataset[i + interval])
X,Y=np.array(dataX), np.array(dataY)
train_size = int( X.shape[0] * trainfactor)
train_X, train_Y, test_X, test_Y = X[:train_size], Y[:train_size], X[train_size:], Y[train_size:]
train_X = train_X.reshape(-1, datainterval )
train_Y = train_Y.reshape(-1, 1 )
test_X = test_X.reshape(-1, datainterval )
test_Y = test_Y.reshape(-1, 1 )
train_x =torch.from_numpy(train_X)
train_y = torch.from_numpy(train_Y)
test_x = torch.from_numpy(test_X)
test_y = torch.from_numpy(test_Y)
return train_x, train_y, test_x, test_y
def train(train_x, train_y):
model = lstm(1, 200, 1, 2)
model.train()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
iternum=5000
for e in range(iternum):
out = model(train_x)
loss = criterion(out, train_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (e + 1) % 1000 == 0:
torch.save(model, modelname )
if (e + 1) % 1000 == 0:
print('迭代数: {}, 损失值: {:.6f}'.format(e + 1, loss.data.item()/train_x.shape[0]) )
print('\r')
def test(num,trainfactor=0.8):
model = torch.load(modelname)
model.eval()
dataset = load(num)
train_size = int(dataset.__len__() * trainfactor)
dataset=dataset[train_size:]
startnum=datainterval
predictnum = 50
x=range(startnum,startnum+predictnum)
y=dataset[startnum:startnum+predictnum]
y_=[]
data=[x for x in dataset[:datainterval]]
for i in range(0,predictnum):
seed = np.array(data[i:i+datainterval],dtype=np.float32)
seed=seed.reshape(-1, datainterval )
p=model(torch.from_numpy(seed))
y_.append(p.item())
data.append(p.item())
plt.plot(x, y, color='blue')
plt.plot(x, y_, color='red')
plt.grid()
plt.show()
if __name__=='__main__':
runtest = True
if (runtest):
test(2000)
else:
train_X, train_Y, test_X, test_Y = create_dataset(500, 0.8, datainterval)
train(train_X, train_Y)
测试模型效果如下图所示:
蓝色线是目标函数走势,红色线是LSTM的预测走势图,模型基本上模拟出了实际数据走势。如果使用GPU并增加迭代次数,模型效果会更好。
二、LSTM反向传播推导
LSTM的反向传播过程相对复杂,但可以通过链式求导法则逐步推导。首先列出LSTM前向传播中的公式组,包括四个门以及两个输出:
为了方便推导,定义遗忘门ft在t时刻的输入:
可以将权重矩阵Wf拆解为Wfh和Wfx两个矩阵,分别与ht-1和xt相乘。类似地,可以定义其他门的输入公式。
假设已知t时刻的误差δt,由于LSTM的输出ht没有使用激活函数,δt定义为:
根据链式求导法则,上一时刻的误差δt-1为:
求δt-1的核心是求出:
观察前向传播的公式组,ht等式右侧的ot和Ct都含有ht-1,而Ct中ft、it、都含有ht-1。通过逐项使用链式求导法则,可以得到完整的反向传播过程。