nn.RNN的输入输出及其内部结构说明
nn.RNN的输入输出及其内部结构说明
本文详细介绍了PyTorch中nn.RNN模块的输入输出格式及其内部结构。通过对比embedding和RNN的输入格式,文章深入解释了RNN的输入输出维度,并通过代码示例展示了如何处理RNN的输出。此外,还介绍了RNN_cell的内部结构和计算过程。
1. input
nn.RNN的输入要求为(seq_len, batch_size, input_size)
,即(T, N, E)
格式。其中:
seq_len
是序列长度batch_size
是批大小input_size
是输入的特征维度
与embedding的输出格式(N, T, E)
不同,这是因为RNN的自身结构决定的。在RNN中,每个时间步都会完成一次神经网络的功能,而embedding则是在每个时间步中向量化一个单词。
图片来自b站耿直哥
在处理数据时,需要注意矩阵的转换。例如:
x = self.embedding_layer(x) # [N,T] -> [N,T,E]
ho, hn = self.rnn(x) # ho [N,T,hidden_size] hn [?,N,E]
hz = torch.permute(hn, dims=[1, 0, 2]) # [?,N,E] -> [N,?,E]
在seq2seq模型中,为了将信息压缩在一个矩阵中,通常还会执行以下步骤:
hz = torch.reshape(hz, shape=(hz.shape[0], self.output_dim)) # [N,?,E] -> [N,?*E]
2. output
nn.RNN会有两个输出,分别是ho(output)
和hn(hidden)
。
2.1 h_o(output)
h_o
会输出RNN在所有时间步上的隐藏状态输出,包含了整个序列在每个时间步的隐藏状态。其输出格式为:
(seq_len, batch_size, num_directions * hidden_size)
,即(T, N, E*(1or2))
。
其中:
seq_len
是输入序列的长度batch_size
是批大小num_directions
是方向数,单向为1,双向为2hidden_size
是隐藏状态的维度
为什么最后一个是hidden_size
而不是神经网络的output_size
呢?原因是h_n
只保留了最后一步的hidden_state,但中间的hidden_state也有可能会参与计算,所以pytorch把中间每一步输出的hidden_state都放到output
中(当然,只保留了hidden_state最后一层的输出),因此,你可以发现这个output
的维度是(seq_len, batch, num_directions * hidden_size)
。
2.2 h_n(hidden)
h_n
的输出格式为:
(num_layers * num_directions, batch_size, hidden_size)
,即(num_layers*(1or2), N, E)
。
如果没有提供,默认为全0。其中:
num_layers
是RNN的层数num_directions
是方向数,如果是单向RNN则为1,如果是双向RNN则为2hidden_size
是隐藏状态的维度
为什么hidden_size
的格式与input
和output
不同,变成了(num_layers, N, E)
的形式呢?简单来说就是因为隐藏状态不是一个时间序列,而是在每一层中都持有一个向量。而输出中间状态就是为了得到每个时刻的隐层输出。所以num_layers * num_directions
这个维度代替了seq_len
。
初始化rnn
rnn = nn.RNN(input_size, hidden_size, num_layers)
无需多言看图即懂:
- 其中Xn是
input_size
, - A(第一层), A'(第二层), A''(第三层) 则是
num_layers
, - 在每个A中都是一个RNN_cell,每个都是一个全连接网络,而
hidden_size
类似于全连接网络中的隐藏层。
RNN_cell
其内部结构如下:
没想到吧依旧是这张图。其实所谓的RNN_cell就是一个全连接神经网络。
内部的计算过程:
- 输入:
input
: 当前时间步的输入,形状为(batch_size, input_size)
即(T, E)
hidden
: 前一时间步的隐藏状态,形状为(batch_size, hidden_size)
- 前向计算:
- 将输入
input
和前一隐藏状态hidden
进行线性变换:gate = F.linear(input, self.weight_ih, self.bias_ih) + \ F.linear(hidden, self.weight_hh, self.bias_hh)
- 将线性变换的结果
gate
应用激活函数(如tanh)得到新的隐藏状态new_hidden
:new_hidden = F.tanh(gate)
- 输出:
new_hidden
: 当前时间步的新隐藏状态,形状为(batch_size, hidden_size)
即(N, E)
这个new_hidden
里的hidden_size
就是前面input
、output
、h_0
的hidden_size
啦。对于RNN整个网络来说,这个new_hidden
是RNN_cell的输出,就是隐层的输出。但是对于RNN_cell来说,则是经过完整的全连接网络并且激活过的output!