问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

LSTM模型内部结构详解

创作时间:
作者:
@小白创作中心

LSTM模型内部结构详解

引用
CSDN
1.
https://blog.csdn.net/qq_44833392/article/details/121840785

LSTM(Long Short-Term Memory)模型是循环神经网络(RNN)的一种变体,特别适用于处理和预测时间序列数据。与传统的RNN相比,LSTM能够更好地捕捉长序列之间的依赖关系,有效缓解了梯度消失或爆炸的问题。本文将深入解析LSTM模型的核心结构及其工作原理。

LSTM模型

LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象

LSTM核心结构

  • 遗忘门
  • 输入门
  • 细胞状态
  • 输出门

LSTM的内部结构图

遗忘门

遗忘门部分结构图与计算公式

遗忘门结构分析

与传统RNN的内部结构计算非常相似, 首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)], 然后通过一个全连接层做变换, 最后通过sigmoid函数进行激活得到f(t), 我们可以将f(t)看作是门值, 好比一扇门开合的大小程度, 门值都将作用在通过该扇门的张量。

遗忘门门值将作用的上一层的细胞状态上, 代表遗忘过去的多少信息, 又因为遗忘门门值是由x(t), h(t-1)计算得来的, 因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.

遗忘门内部结构过程演示

激活函数sigmiod的作用

  • 用于帮助调节流经网络的值, sigmoid函数将值压缩在0和1之间.

输入门

输入门部分结构图与计算公式

输入门结构分析

我们看到输入门的计算公式有两个, 第一个就是产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤. 输入门的第二个公式是与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态.

输入门内部结构过程演示

细胞更新状态

细胞状态更新图与计算公式

细胞状态更新分析

细胞更新的结构与计算公式非常容易理解, 这里没有全连接层, 只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘, 再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果. 最终得到更新后的C(t)作为下一个时间步输入的一部分. 整个细胞状态更新过程就是对遗忘门和输入门的应用.

细胞状态更新过程演示

输出门

输出门部分结构图与计算公式

输出门结构分析

输出门部分的公式也是两个, 第一个即是计算输出门的门值, 它和遗忘门,输入门计算方式相同. 第二个即是使用这个门值产生隐含状态h(t), 他将作用在更新后的细胞状态C(t)上, 并做tanh激活, 最终得到h(t)作为下一时间步输入的一部分. 整个输出门的过程, 就是为了产生隐含状态h(t).

输出门内部结构过程演示

举例

'''
Description: lstm举例
Autor: 365JHWZGo
Date: 2021-12-09 19:20:23
LastEditors: 365JHWZGo
LastEditTime: 2021-12-09 19:28:08
'''
import torch
import torch.nn as nn
torch.manual_seed(1)
TIME_STEP = 1
INPUT_SIZE = 5
HIDDEN_LAYER = 2
HIDDEN_SIZE = 6
BATCH_SIZE = 3
input_data = torch.randn(TIME_STEP,BATCH_SIZE,INPUT_SIZE)
h0 = torch.randn(HIDDEN_LAYER,BATCH_SIZE,HIDDEN_SIZE)
c0 = torch.randn(HIDDEN_LAYER,BATCH_SIZE,HIDDEN_SIZE)
lstm = nn.LSTM(INPUT_SIZE,HIDDEN_SIZE,HIDDEN_LAYER)
output,(h_,c_) = lstm(input_data,(h0,c0))
print(f'output:\t{output}')
print(f'h_:\t{h_}')
print(f'c_:\t{c_}')
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号