问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

必看!一文看懂长短期记忆网络(LSTM)

创作时间:
作者:
@小白创作中心

必看!一文看懂长短期记忆网络(LSTM)

引用
CSDN
1.
https://m.blog.csdn.net/Code1994/article/details/145631949

循环神经网络(RNN)是处理序列数据的得力工具。然而,普通RNN在处理长序列时存在梯度消失或梯度爆炸的问题,长短期记忆网络(LSTM)应运而生,它有效解决了这些问题,在诸多领域大放异彩。今天,就让我们一同深入了解LSTM。

长短期记忆网络(通常简称为“LSTM”)是一种特殊的循环神经网络(RNN),能够学习长期依赖关系。这种网络最早是由 Hochreiter 和 Schmidhuber 在1997年提出的,后来经过很多人进一步完善和推广。LSTM在各种问题上表现得非常出色,现在已经被广泛应用了。

LSTM的设计初衷就是为了避免长期依赖问题。它们天生就擅长记住长时间的信息,这并不是它们需要努力去学习的技能,而是它们的“本能”。

所有循环神经网络都有一个链式结构,由重复的神经网络模块组成。在普通的RNN中,这个重复模块的结构非常简单,比如只有一个tanh层。

LSTM也有这种链式结构,但它的重复模块和普通RNN不一样。它不是只有一个神经网络层,而是有四个,而且这四个层之间的互动方式特别独特。

别担心这些细节听起来有点复杂。我们下面会一步步详细讲解LSTM的图解。

一、基本架构和原理

LSTM的基本架构由一个细胞状态(Cell State)和三个门控结构组成,这三个门控分别是输入门(Input Gate)、遗忘门(Forget Gate)和输出门(Output Gate)。

LSTM的关键在于细胞状态(cell state),就是图上那条横着穿过顶部的线。

细胞状态就像是LSTM的“记忆高速公路”,它贯穿整个LSTM单元,能够在序列中传递长期信息。数据可以直接通过细胞状态在不同时间步之间传递,减少信息的丢失。

LSTM确实有能力从细胞状态中移除或者添加信息,但这个过程是由一种叫“门”的结构严格控制的。

简单来说,门的作用就是决定要不要让信息通过。它们是由一个S型神经网络层(sigmoid layer)和逐点乘法运算组成的。

S型神经网络层会输出介于0到1之间的数字,用来决定每部分信息该放多少过去。要是输出是0,那就是“啥都别放过去”;要是输出是1,那就是“全放过去”!

LSTM里有三个这样的门,专门用来保护和控制细胞状态。

输入门决定了当前输入的信息有多少将被存入细胞状态。它通过一个Sigmoid层和一个tanh层共同作用,Sigmoid层输出0到1之间的值,表示每个输入特征的重要性程度,tanh层则对输入数据进行变换,最终两者结果相乘,确定哪些新信息将被添加到细胞状态中。

遗忘门控制着细胞状态中哪些信息需要被遗忘。同样由Sigmoid层实现,输出值在0到1之间,0表示完全遗忘,1表示完全保留,通过与上一时刻的细胞状态相乘,决定保留或丢弃哪些历史信息。

输出门决定了细胞状态中的哪些信息将被输出。它先利用Sigmoid层确定细胞状态中哪些部分将被输出,再通过tanh层对细胞状态进行变换,最后两者结果相乘,得到最终的输出。

二、LSTM的逐步解析

LSTM的第一步是要决定从细胞状态里扔掉哪些信息。

1.遗忘门计算

这个决定是由一个叫“遗忘门层”的S型神经网络层来做的。它会看看上一个时间步的隐藏状态 和当前的输入 ,然后对细胞状态 里的每个数字输出一个0到1之间的值。1的意思是“完全保留这个”,0的意思是“完全扔掉这个”。

  • 计算遗忘门的激活值:
  • 对通过权重矩阵和偏置进行线性变换,再经过Sigmoid函数,。

1.输入门计算

下一步就是决定我们要把哪些新信息存到细胞状态里。这事儿分两步走。首先,一个叫“输入门层”的S型神经网络层会决定我们要更新哪些值。然后,一个tanh层会生成一组新的候选值 ,这些值可能会被加到状态里。在下一步,我们会把这两部分结合起来,更新细胞状态。

  • 计算输入门的激活值:
  • 首先,将当前输入和上一时刻的隐藏状态进行拼接,得到。
  • 然后,通过权重矩阵和偏置进行线性变换,再经过Sigmoid函数,即。
  • 计算候选细胞状态:
  • 同样对进行线性变换,这次使用权重矩阵和偏置,再经过tanh函数,得到。

3.更新细胞状态

现在该把旧的细胞状态 变成新的细胞状态 。

之前的步骤都已经定好了该咋弄,咱就按那法子来就行啦。咱把旧状态乘以 ,这样就能把之前打算忘掉的信息给扔掉啦。然后呢,再加上 ,这就是新的候选值啦,这个候选值是根据我们之前决定的更新每个状态值的量来调整的哦。要是在语言模型里呢,这就是我们把旧主语的性别信息去掉,再添上些新信息的地方啦,这些事儿在之前的步骤里都已经定好咯。

  • 计算当前时刻的细胞状态:
  • 用上一时刻的细胞状态乘以遗忘门,表示遗忘部分历史信息,再加上输入门与候选细胞状态的乘积,即。

4.输出门计算

最后,我们得决定要输出啥。这个输出是基于我们的细胞状态的,但会是一个经过筛选的版本哦。首先呢,我们运行一个 S 型函数层,它会决定我们要输出细胞状态的哪些部分。然后呢,我们把细胞状态通过 tanh 函数(这样可以把值推到 -1 到 1 之间),再把它和 S 型函数门的输出相乘,这样我们就只输出我们决定输出的部分啦。

  • 计算输出门的激活值:
  • 对通过权重矩阵和偏置进行线性变换,再经过Sigmoid函数,。
  • 计算当前时刻的隐藏状态:
  • 先对细胞状态通过tanh函数进行变换,再乘以输出门,即。

三、优缺点

优点

  1. 长期记忆能力:LSTM能够有效地处理长序列数据中的长期依赖关系,通过门控机制选择性地记忆和遗忘信息,避免了梯度消失或梯度爆炸问题,这使得它在处理语音识别、自然语言处理等需要长期记忆的任务中表现出色。

  2. 适应性强:可以处理不同长度的序列数据,在各种领域都有广泛的应用,无论是文本、语音还是时间序列数据等。

  3. 鲁棒性好:相比普通RNN,LSTM对噪声和异常值具有更好的鲁棒性,能够在数据质量不太理想的情况下依然保持较好的性能。

缺点

  1. 计算复杂度高:由于LSTM包含多个门控结构和复杂的计算过程,其计算量较大,训练时间较长。在处理大规模数据和实时性要求较高的任务时,可能会面临一定的挑战。

  2. 调参困难:LSTM有较多的超参数,如隐藏层大小、学习率、权重初始化等,这些超参数的设置对模型性能有较大影响,需要花费大量时间和精力进行调参。

  3. 过拟合风险:如果模型结构过于复杂,而训练数据量相对较少时,LSTM可能会出现过拟合现象,导致模型在测试集上的性能下降。

四、相关案例代码

下面我们用Python和Keras库来实现一个简单的LSTM模型,用于预测正弦波时间序列。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 生成正弦波数据
time_steps = 50
data = []
for i in range(1000):
    x = np.sin(np.linspace(i, i + time_steps, time_steps))
    data.append(x)
data = np.array(data)
data = np.reshape(data, (data.shape[0], data.shape[1], 1))
data = torch.FloatTensor(data)

# 划分训练集和测试集
train_size = int(len(data) * 0.8)
train_data = data[:train_size]
test_data = data[train_size:]

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.lstm(x)
        out = self.fc(out[:, -1, :])
        return out

input_size = 1
hidden_size = 50
output_size = 1
model = LSTMModel(input_size, hidden_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 50
for epoch in range(num_epochs):
    outputs = model(train_data)
    loss = criterion(outputs, train_data[:, -1, :])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

# 预测
with torch.no_grad():
    train_predict = model(train_data)
    test_predict = model(test_data)
train_predict = train_predict.numpy()
test_predict = test_predict.numpy()
train_data = train_data.numpy()
test_data = test_data.numpy()

# 绘制结果
plt.figure(figsize=(12, 6))
plt.plot(np.arange(train_size), train_data[:, -1, 0], label='Train Data')
plt.plot(np.arange(train_size, train_size + len(test_data)), test_data[:, -1, 0], label='Test Data')
plt.plot(np.arange(train_size), train_predict[:, 0], label='Train Prediction', linestyle='--')
plt.plot(np.arange(train_size, train_size + len(test_data)), test_predict[:, 0], label='Test Prediction', linestyle='--')
plt.legend()
plt.show()

在这个例子中,我们首先生成了正弦波时间序列数据,然后将其划分为训练集和测试集。接着构建了一个简单的LSTM模型,包含一个LSTM层和一个全连接层。通过编译和训练模型,最后对训练集和测试集进行预测,并将结果可视化。

长短期记忆网络(LSTM)凭借其独特的架构和强大的功能,在人工智能领域发挥着重要作用。希望通过本文的介绍,能让大家对LSTM有更深入的理解,激发大家在相关领域的探索和实践。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号