问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Transformer系列:残差连接原理详细解析和代码论证

创作时间:
作者:
@小白创作中心

Transformer系列:残差连接原理详细解析和代码论证

引用
51CTO
1.
https://blog.51cto.com/u_16163453/12406781

残差连接(Residual Connection)是深度学习中解决深层网络训练问题的关键技术之一。本文将从历史背景、原理分析、代码实现和实验验证等多个维度,深入探讨残差连接在Transformer中的应用及其重要作用。

残差连接的历史由来

残差连接最早可以追溯到2015年何凯明等人提出的ResNet。这种结构的引入解决了神经网络随着层数增多而难以训练的问题,主要缓解了梯度消失、梯度爆炸和网络退化等问题,使得网络可以拓展到更深的层次。

Transformer中的残差连接

Transformer也采用了残差连接(residual connection)这一标准结构。在Transformer的Encoder和Decoder中,层与层之间通过"ADD & Norm"的方式进行连接。

具体来说,"Add"操作是指将本层的输出和本层的输入在对应位置相加,且要求两者的维度相同。以下是一个Encoder中的实现示例:

output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
output = self.norm_layer(Add()([enc_input, output]))

其中,self_att_layer是Encoder中的多头注意力机制,enc_input是多头注意力的输入,而output是多头注意力的输出。通过Keras的Add()算子将enc_inputoutput相加,得到最终的输出。

深层网络的问题代码复现

为了验证深层网络带来的训练问题,我们构建了一个可以传入层数参数的Dense网络模型。通过实验测试不同层数(3、10、20、35、50、100)下的模型训练效果。

class Model(object):
    def __init__(self, num_class, feature_size, layer_num=100, learning_rate=0.001, weight_decay=0.01, decay_learning_rate=1):
        self.input_x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
        self.input_y = tf.placeholder(tf.float32, [None, num_class], name="input_y")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        self.batch_normalization = tf.placeholder(tf.bool, name="batch_normalization")
        self.global_step = tf.Variable(0, name="global_step", trainable=False)
        tmp_tensor = self.input_x
        for i in range(layer_num):
            with tf.variable_scope('layer_{}'.format(i + 1)):
                dense_out_1 = tf.layers.dense(tmp_tensor, 32)
                dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
                tmp_tensor = tf.nn.relu(dense_out_1)
        with tf.variable_scope('layer_out'):
            self.output = tf.layers.dense(tmp_tensor, 2)
            self.probs = tf.nn.softmax(self.output, dim=1, name="probs")
        with tf.variable_scope('loss'):
            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.output, labels=self.input_y))
            vars = tf.trainable_variables()
            loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if
                                v.name not in ['bias', 'gamma', 'b', 'g', 'beta']]) * weight_decay
            self.loss += loss_l2
        with tf.variable_scope("optimizer"):
            if decay_learning_rate:
                learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, decay_learning_rate)
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)
        with tf.variable_scope("metrics"):
            self.accuracy = tf.reduce_mean(
                tf.cast(tf.equal(tf.arg_max(self.probs, 1), tf.arg_max(self.input_y, 1)), dtype=tf.float32))

实验结果显示,随着网络层数的增加,模型在训练集和测试集上的准确率都在下降。当网络深度达到35层以上时,模型几乎无法收敛,准确率维持在50%左右。

深层网络的问题分析

深层网络面临的主要问题包括:

  • 梯度消失:在反向传播过程中,由于权重的累乘,当权重接近0时,梯度会迅速消失。
  • 梯度爆炸:与梯度消失相反,当权重较大时,累乘会导致梯度爆炸。
  • 网络退化:理论上存在一个最优网络深度,超过这个深度后,额外的网络层不仅不会提升性能,反而可能导致性能下降。

残差连接的作用通俗理解

残差连接的核心思想是在每一层或每隔几层引入一个"兜底策略"。其目的是确保即使模型已经达到最优深度,增加冗余层也不会导致性能下降。

具体来说,假设模型共有50层,如果在第16层时模型已经达到最佳效果,那么从第17层开始到第50层,模型将学习一种恒等变换,最终在最后一层将第16层的输出恒等映射出来。

残差连接的实现方式是将上一层的输出直接连接到下一层的输出,即上一层的输出与下一层的原始输出对应位置相加形成最终输出。

其中,X代表逐渐逼近最优结果的上层输出,而F(x)代表残差,表示还可以进一步优化的部分。当模型深度达到最优值时,残差连接可以自适应地将F(x)学习为全0,从而实现恒等变换。

残差连接和GBDT类比

残差连接的思想与GBDT(Gradient Boosting Decision Tree)有相似之处。GBDT通过多个基学习器拟合之前的残差,而残差连接则以X为基线,F(x)聚焦于可进一步优化的部分。不同的是,GBDT的每个基分类器结果直接相加得到预测结果,而残差连接在网络中间层,最后还需要通过全连接层进行任务分类。

残差连接的作用公式理解

在普通网络中,从x到y4的计算过程如下:

y1 = σ(w1 * x + b1)
y2 = σ(w2 * y1 + b2)
y3 = σ(w3 * y2 + b3)
y4 = σ(w4 * y3 + b4)

加入残差连接后,计算过程变为:

y1 = σ(w1 * x + b1) + x
y2 = σ(w2 * y1 + b2) + y1
y3 = σ(w3 * y2 + b3) + y2
y4 = σ(w4 * y3 + b4) + y3

即:

其中,X代表某层的输出,某个高层I的输出等于某个低层i的输入加上两层之间所有残差F的结果。此时,对低层i求梯度的结果如下:

括号展开的第一项直接就是高层I的梯度,作为因子直接作用到低层i的梯度,而不是像普通网络那样经过各种累乘放大或缩小。等式右侧是一个累加,相比于原来的累乘,一定程度上降低了梯度爆炸和弥散的概率。

深层网络运用残差连接代码实践

在之前的深层网络代码中加入残差连接:

for i in range(layer_num):
    with tf.variable_scope('layer_{}'.format(i + 1)):
        dense_out_1 = tf.layers.dense(tmp_tensor, 32)
        dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
        if i != 0:
            # 残差连接
            dense_out_1 = tf.nn.relu(dense_out_1)
            dense_out_1 = tf.add(dense_out_1, tmp_tensor)
        # bn
        tmp_tensor = tf.nn.relu(dense_out_1)

实验结果显示,加入残差连接后,不同层数的模型都能较好地收敛。虽然35层和100层的性能略低于其他层,但差距并不大,整体准确率保持在0.7左右,相比普通网络有显著提升。

总结

残差连接是深度学习中解决深层网络训练问题的关键技术。通过理论分析、代码实现和实验验证,我们可以看到残差连接如何有效缓解梯度消失、梯度爆炸和网络退化等问题,使得深层网络的训练成为可能。这一技术在Transformer等现代深度学习模型中得到了广泛应用,对于推动深度学习的发展具有重要意义。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号