问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

PyTorch从零实现LSTM模型:最简洁的非封装版本

创作时间:
作者:
@小白创作中心

PyTorch从零实现LSTM模型:最简洁的非封装版本

引用
CSDN
1.
https://blog.csdn.net/qq_64809150/article/details/140432814

LSTM(长短期记忆网络)是一种特殊的循环神经网络(RNN),能够学习长期依赖关系,广泛应用于序列数据处理任务,如时间序列预测、自然语言处理等。本文将介绍如何使用PyTorch从头搭建一个LSTM模型,不使用任何第三方封装,帮助读者深入理解LSTM的工作原理。

前不久在项目中使用到了LSTM模型,为了更好地理解其工作原理,我根据Colah的博客《Understanding LSTM Networks》从零开始实现了一个LSTM模型。本文适合已经对LSTM模型有一定了解的读者,建议先阅读《Understanding LSTM Networks》一文,因为本文的代码实现完全基于该文章。

接下来是LSTM模型的完整代码实现:

import torch
import torch.nn as nn

class RiceLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RiceLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.Wf = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wi = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wo = nn.Linear(input_size + hidden_size, hidden_size)
        self.Wc = nn.Linear(input_size + hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        batch_size = input.size(0)
        seq_len = input.size(1)
        hidden_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)
        cell_state = torch.zeros(batch_size, self.hidden_size, dtype=torch.float32)
        outputs = []
        for i in range(seq_len):
            combined = torch.cat((input[:, i, :], hidden_state), dim=1)
            f_t = self.sigmoid(self.Wf(combined))
            i_t = self.sigmoid(self.Wi(combined))
            o_t = self.sigmoid(self.Wo(combined))
            c_hat_t = self.tanh(self.Wc(combined))
            cell_state = f_t * cell_state + i_t * c_hat_t
            hidden_state = o_t * self.tanh(cell_state)
            outputs.append(hidden_state.unsqueeze(1))
        outputs = torch.cat(outputs, dim=1)
        final_output = self.output_layer(outputs)
        return final_output, (hidden_state, cell_state)

上述模型是一个最简单的LSTM模型,代码看上去可能有些复杂,但通过仔细阅读注释和理解LSTM的工作原理,可以逐步掌握其实现细节。模型的输入特征向量是影响因素,这里以水稻产量预测为例,使用温度和湿度作为特征向量进行测试。每个特征向量随机产生100个数据,按照80%训练集和20%测试集的比例划分数据集,因此在代码中可以看到维度为80的输出。

模型的效果展示如下(注意:由于数据集较小,仅包含150条记录,且特征向量较多,因此预测效果并不理想,这是神经网络的常见问题。如果减少特征向量的数量,即使数据集较小,也能获得更准确的预测结果):



本文代码和模型实现均参考自《Understanding LSTM Networks》,强烈建议读者阅读该文章,因为几乎所有关于LSTM的经典图片都出自这里。

文章来源:CSDN

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号