问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

nn.RNN的输入输出及其内部结构说明

创作时间:
作者:
@小白创作中心

nn.RNN的输入输出及其内部结构说明

引用
CSDN
1.
https://m.blog.csdn.net/m0_62965652/article/details/138140007

本文详细介绍了PyTorch中nn.RNN模块的输入输出格式及其内部结构。通过对比embedding和RNN的输入格式,文章深入解释了RNN的输入输出维度,并通过代码示例展示了如何处理RNN的输出。此外,还介绍了RNN_cell的内部结构和计算过程。

1. input

nn.RNN的输入要求为(seq_len, batch_size, input_size),即(T, N, E)格式。其中:

  • seq_len是序列长度
  • batch_size是批大小
  • input_size是输入的特征维度

与embedding的输出格式(N, T, E)不同,这是因为RNN的自身结构决定的。在RNN中,每个时间步都会完成一次神经网络的功能,而embedding则是在每个时间步中向量化一个单词。


图片来自b站耿直哥

在处理数据时,需要注意矩阵的转换。例如:

x = self.embedding_layer(x)  # [N,T] -> [N,T,E]
ho, hn = self.rnn(x)  # ho [N,T,hidden_size] hn [?,N,E]
hz = torch.permute(hn, dims=[1, 0, 2])  # [?,N,E] -> [N,?,E]

在seq2seq模型中,为了将信息压缩在一个矩阵中,通常还会执行以下步骤:

hz = torch.reshape(hz, shape=(hz.shape[0], self.output_dim))  # [N,?,E] -> [N,?*E]

2. output

nn.RNN会有两个输出,分别是ho(output)hn(hidden)

2.1 h_o(output)

h_o会输出RNN在所有时间步上的隐藏状态输出,包含了整个序列在每个时间步的隐藏状态。其输出格式为:

(seq_len, batch_size, num_directions * hidden_size),即(T, N, E*(1or2))

其中:

  • seq_len是输入序列的长度
  • batch_size是批大小
  • num_directions是方向数,单向为1,双向为2
  • hidden_size是隐藏状态的维度

为什么最后一个是hidden_size而不是神经网络的output_size呢?原因是h_n只保留了最后一步的hidden_state,但中间的hidden_state也有可能会参与计算,所以pytorch把中间每一步输出的hidden_state都放到output中(当然,只保留了hidden_state最后一层的输出),因此,你可以发现这个output的维度是(seq_len, batch, num_directions * hidden_size)

2.2 h_n(hidden)

h_n的输出格式为:

(num_layers * num_directions, batch_size, hidden_size),即(num_layers*(1or2), N, E)

如果没有提供,默认为全0。其中:

  • num_layers是RNN的层数
  • num_directions是方向数,如果是单向RNN则为1,如果是双向RNN则为2
  • hidden_size是隐藏状态的维度

为什么hidden_size的格式与inputoutput不同,变成了(num_layers, N, E)的形式呢?简单来说就是因为隐藏状态不是一个时间序列,而是在每一层中都持有一个向量。而输出中间状态就是为了得到每个时刻的隐层输出。所以num_layers * num_directions这个维度代替了seq_len

初始化rnn

rnn = nn.RNN(input_size, hidden_size, num_layers)

无需多言看图即懂:

  • 其中Xn是input_size
  • A(第一层), A'(第二层), A''(第三层) 则是num_layers
  • 在每个A中都是一个RNN_cell,每个都是一个全连接网络,而hidden_size类似于全连接网络中的隐藏层。

RNN_cell

其内部结构如下:


没想到吧依旧是这张图。其实所谓的RNN_cell就是一个全连接神经网络。

内部的计算过程:

  1. 输入:
  • input: 当前时间步的输入,形状为(batch_size, input_size)(T, E)
  • hidden: 前一时间步的隐藏状态,形状为(batch_size, hidden_size)
  1. 前向计算:
  • 将输入input和前一隐藏状态hidden进行线性变换:
    gate = F.linear(input, self.weight_ih, self.bias_ih) + \
           F.linear(hidden, self.weight_hh, self.bias_hh)
    
  • 将线性变换的结果gate应用激活函数(如tanh)得到新的隐藏状态new_hidden:
    new_hidden = F.tanh(gate)
    
  1. 输出:
  • new_hidden: 当前时间步的新隐藏状态,形状为(batch_size, hidden_size)(N, E)

这个new_hidden里的hidden_size就是前面inputoutputh_0hidden_size啦。对于RNN整个网络来说,这个new_hidden是RNN_cell的输出,就是隐层的输出。但是对于RNN_cell来说,则是经过完整的全连接网络并且激活过的output!

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号