Transformer系列:残差连接原理详细解析和代码论证
Transformer系列:残差连接原理详细解析和代码论证
残差连接(Residual Connection)是深度学习中解决深层网络训练问题的关键技术之一。本文将从历史背景、原理分析、代码实现和实验验证等多个维度,深入探讨残差连接在Transformer中的应用及其重要作用。
残差连接的历史由来
残差连接最早可以追溯到2015年何凯明等人提出的ResNet。这种结构的引入解决了神经网络随着层数增多而难以训练的问题,主要缓解了梯度消失、梯度爆炸和网络退化等问题,使得网络可以拓展到更深的层次。
Transformer中的残差连接
Transformer也采用了残差连接(residual connection)这一标准结构。在Transformer的Encoder和Decoder中,层与层之间通过"ADD & Norm"的方式进行连接。
具体来说,"Add"操作是指将本层的输出和本层的输入在对应位置相加,且要求两者的维度相同。以下是一个Encoder中的实现示例:
output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
output = self.norm_layer(Add()([enc_input, output]))
其中,self_att_layer
是Encoder中的多头注意力机制,enc_input
是多头注意力的输入,而output
是多头注意力的输出。通过Keras的Add()
算子将enc_input
和output
相加,得到最终的输出。
深层网络的问题代码复现
为了验证深层网络带来的训练问题,我们构建了一个可以传入层数参数的Dense网络模型。通过实验测试不同层数(3、10、20、35、50、100)下的模型训练效果。
class Model(object):
def __init__(self, num_class, feature_size, layer_num=100, learning_rate=0.001, weight_decay=0.01, decay_learning_rate=1):
self.input_x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
self.input_y = tf.placeholder(tf.float32, [None, num_class], name="input_y")
self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
self.batch_normalization = tf.placeholder(tf.bool, name="batch_normalization")
self.global_step = tf.Variable(0, name="global_step", trainable=False)
tmp_tensor = self.input_x
for i in range(layer_num):
with tf.variable_scope('layer_{}'.format(i + 1)):
dense_out_1 = tf.layers.dense(tmp_tensor, 32)
dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
tmp_tensor = tf.nn.relu(dense_out_1)
with tf.variable_scope('layer_out'):
self.output = tf.layers.dense(tmp_tensor, 2)
self.probs = tf.nn.softmax(self.output, dim=1, name="probs")
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.output, labels=self.input_y))
vars = tf.trainable_variables()
loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if
v.name not in ['bias', 'gamma', 'b', 'g', 'beta']]) * weight_decay
self.loss += loss_l2
with tf.variable_scope("optimizer"):
if decay_learning_rate:
learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, decay_learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)
with tf.variable_scope("metrics"):
self.accuracy = tf.reduce_mean(
tf.cast(tf.equal(tf.arg_max(self.probs, 1), tf.arg_max(self.input_y, 1)), dtype=tf.float32))
实验结果显示,随着网络层数的增加,模型在训练集和测试集上的准确率都在下降。当网络深度达到35层以上时,模型几乎无法收敛,准确率维持在50%左右。
深层网络的问题分析
深层网络面临的主要问题包括:
- 梯度消失:在反向传播过程中,由于权重的累乘,当权重接近0时,梯度会迅速消失。
- 梯度爆炸:与梯度消失相反,当权重较大时,累乘会导致梯度爆炸。
- 网络退化:理论上存在一个最优网络深度,超过这个深度后,额外的网络层不仅不会提升性能,反而可能导致性能下降。
残差连接的作用通俗理解
残差连接的核心思想是在每一层或每隔几层引入一个"兜底策略"。其目的是确保即使模型已经达到最优深度,增加冗余层也不会导致性能下降。
具体来说,假设模型共有50层,如果在第16层时模型已经达到最佳效果,那么从第17层开始到第50层,模型将学习一种恒等变换,最终在最后一层将第16层的输出恒等映射出来。
残差连接的实现方式是将上一层的输出直接连接到下一层的输出,即上一层的输出与下一层的原始输出对应位置相加形成最终输出。
其中,X代表逐渐逼近最优结果的上层输出,而F(x)代表残差,表示还可以进一步优化的部分。当模型深度达到最优值时,残差连接可以自适应地将F(x)学习为全0,从而实现恒等变换。
残差连接和GBDT类比
残差连接的思想与GBDT(Gradient Boosting Decision Tree)有相似之处。GBDT通过多个基学习器拟合之前的残差,而残差连接则以X为基线,F(x)聚焦于可进一步优化的部分。不同的是,GBDT的每个基分类器结果直接相加得到预测结果,而残差连接在网络中间层,最后还需要通过全连接层进行任务分类。
残差连接的作用公式理解
在普通网络中,从x到y4的计算过程如下:
y1 = σ(w1 * x + b1)
y2 = σ(w2 * y1 + b2)
y3 = σ(w3 * y2 + b3)
y4 = σ(w4 * y3 + b4)
加入残差连接后,计算过程变为:
y1 = σ(w1 * x + b1) + x
y2 = σ(w2 * y1 + b2) + y1
y3 = σ(w3 * y2 + b3) + y2
y4 = σ(w4 * y3 + b4) + y3
即:
其中,X代表某层的输出,某个高层I的输出等于某个低层i的输入加上两层之间所有残差F的结果。此时,对低层i求梯度的结果如下:
括号展开的第一项直接就是高层I的梯度,作为因子直接作用到低层i的梯度,而不是像普通网络那样经过各种累乘放大或缩小。等式右侧是一个累加,相比于原来的累乘,一定程度上降低了梯度爆炸和弥散的概率。
深层网络运用残差连接代码实践
在之前的深层网络代码中加入残差连接:
for i in range(layer_num):
with tf.variable_scope('layer_{}'.format(i + 1)):
dense_out_1 = tf.layers.dense(tmp_tensor, 32)
dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
if i != 0:
# 残差连接
dense_out_1 = tf.nn.relu(dense_out_1)
dense_out_1 = tf.add(dense_out_1, tmp_tensor)
# bn
tmp_tensor = tf.nn.relu(dense_out_1)
实验结果显示,加入残差连接后,不同层数的模型都能较好地收敛。虽然35层和100层的性能略低于其他层,但差距并不大,整体准确率保持在0.7左右,相比普通网络有显著提升。
总结
残差连接是深度学习中解决深层网络训练问题的关键技术。通过理论分析、代码实现和实验验证,我们可以看到残差连接如何有效缓解梯度消失、梯度爆炸和网络退化等问题,使得深层网络的训练成为可能。这一技术在Transformer等现代深度学习模型中得到了广泛应用,对于推动深度学习的发展具有重要意义。