Transformer系列:残差连接原理详细解析和代码论证
Transformer系列:残差连接原理详细解析和代码论证
Transformer中的残差连接(Residual Connection)是其核心组件之一,它借鉴了ResNet中的残差结构,解决了深层网络训练中常见的梯度消失、梯度爆炸和网络退化问题。本文将从历史背景、原理分析、代码实现到实验验证,全面解析Transformer中的残差连接机制。
残差连接的历史由来
残差连接最早可以追溯到2015年何凯明等人提出的ResNet,它使得残差连接/网络成为一种基准模型结构。残差连接主要解决了神经网络随着层数的增多变得难以训练的问题,特别是梯度消失、梯度爆炸和网络退化等问题。通过引入残差连接,网络可以拓展到更深的层数。
Transformer中的残差连接
Transformer也采用了残差连接(residual connection)这种标准结构。在Transformer的Encoder和Decoder中,层与层之间加入了ADD & Norm操作,其中ADD就是残差连接。具体来说,就是将本层的输出和本层的输入对应位置相加(本层的输出和本层的输入维度相等)作为最终的输出。
在Transformer的实现代码中,以Encoder为例:
output, slf_attn = self.self_att_layer(enc_input, enc_input, enc_input, mask=mask)
output = self.norm_layer(Add()([enc_input, output]))
其中,self_att_layer
是Encoder中的多头注意力,enc_input
是多头注意力的输入,output
是多头注意力的输出。通过Keras的Add()
算子将enc_input
和output
相加得到最终的output
。
深层网络的问题代码复现
为了验证深层网络带来的训练问题,我们构建了一个可以传入层数参数的Dense网络:
class Model(object):
def __init__(self, num_class, feature_size, layer_num=100, learning_rate=0.001, weight_decay=0.01, decay_learning_rate=1):
self.input_x = tf.placeholder(tf.float32, [None, feature_size], name="input_x")
self.input_y = tf.placeholder(tf.float32, [None, num_class], name="input_y")
self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
self.batch_normalization = tf.placeholder(tf.bool, name="batch_normalization")
self.global_step = tf.Variable(0, name="global_step", trainable=False)
tmp_tensor = self.input_x
for i in range(layer_num):
with tf.variable_scope('layer_{}'.format(i + 1)):
dense_out_1 = tf.layers.dense(tmp_tensor, 32)
dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
tmp_tensor = tf.nn.relu(dense_out_1)
with tf.variable_scope('layer_out'):
self.output = tf.layers.dense(tmp_tensor, 2)
self.probs = tf.nn.softmax(self.output, dim=1, name="probs")
with tf.variable_scope('loss'):
self.loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits_v2(logits=self.output, labels=self.input_y))
vars = tf.trainable_variables()
loss_l2 = tf.add_n([tf.nn.l2_loss(v) for v in vars if
v.name not in ['bias', 'gamma', 'b', 'g', 'beta']]) * weight_decay
self.loss += loss_l2
with tf.variable_scope("optimizer"):
if decay_learning_rate:
learning_rate = tf.train.exponential_decay(learning_rate, self.global_step, 100, decay_learning_rate)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)
with tf.variable_scope("metrics"):
self.accuracy = tf.reduce_mean(
tf.cast(tf.equal(tf.arg_max(self.probs, 1), tf.arg_max(self.input_y, 1)), dtype=tf.float32))
通过实验发现,随着网络层数的加深,模型在训练和测试的准确率都在下降。当网络达到35层以上时,已经很难进行学习收敛,准确率在50%左右。
深层网络的问题分析
深层网络主要存在三个问题:
- 梯度消失:反向传播中梯度的计算存在模型参数w的累乘,w接近0累乘导致梯度接近0梯度消失。
- 梯度爆炸:同理梯度消失,若w较大,累成导致梯度爆炸。
- 网络退化:理论上网络存在一个最优层数,超过这个层数带来的冗余结构的效果并不超过该最优层数下的模型效果,这些冗余层数会带来网络退化。
残差连接的作用通俗理解
残差连接的目的是使得就算模型的深度已经达到最优解,后面再增加冗余层也至少不会导致之前的效果下降。具体来说,就是将上一层的输出直接连接到下一层的输出,即上一层的输出直接和下一层的原始输出对应位置相加形成最终输出。
残差连接的做法是将上一层的输出直接连接到下一层的输出,即上一层的输出和下一层的原始输出对应位置相加形成最终输出。如图所示:
其中X代表一个逐渐逼近最优结果的上层输出,而F(x)代表残差,表示还可以再逼近最优效果的网络结构。当模型深度已经达到最优值的时候,残差连接可以自适应的将F(x)学习为全0,由于有relu的存在残差网络很容易将F(x)全部置为0,此时relu(F(x)+X)转化为relu(X),而由于relu的性质得知,relu(X)=X,因为X已经经过上一层的rule变换,再经过一次relu还是X,从而实现了恒等变换。
残差连接和GBDT类比
残差连接这种上一层作为基线,下一层拟合残差不断逼近最优结果的思想和GBDT很类似。GBDT用一个个基学习器拟合之前所有基学习器剩下的残差,而残差连接以X为基线,F(x)聚焦于还可学习的微小部分,差异在于GBDT每个基分类结果相加做logit即可得到预测结果,残差连接在网络中间层,最后还要套一层全连接进行任务分类。
残差连接的作用公式理解
在深层网络的问题分析那一段中有普通网络从x到y4的计算过程,加入残差连接之后y1到y4的计算如下:
y1 = σ(w1 * x + b1) + x
y2 = σ(w2 * y1 + b2) + y1
y3 = σ(w3 * y2 + b3) + y2
y4 = σ(w4 * y3 + b4) + y3
即:
其中X代表某层的输出,某个高层I的输出等于某个低层i输入加上两层之间所有残差F的结果,此时若要对低层的i求梯度,结果如下:
括号展开第一项直接就是高层I的梯度,直接作为一个因子直接作用到低层的i梯度,而不是像普通网络经过各种累乘放大或者缩小,等式右侧是一个累加,相比于原来的累成一定程度上降低了梯度爆炸和弥散的概率。
深层网络运用残差连接代码实践
修改深层网络的问题代码复现一段中的代码,使其在每一层的输出都加上该层的输入:
for i in range(layer_num):
with tf.variable_scope('layer_{}'.format(i + 1)):
dense_out_1 = tf.layers.dense(tmp_tensor, 32)
dense_out_1 = batch_norm_layer(dense_out_1, is_training=self.batch_normalization, scope="bn{}".format(i + 1))
if i != 0:
# 残差连接
dense_out_1 = tf.nn.relu(dense_out_1)
dense_out_1 = tf.add(dense_out_1, tmp_tensor)
# bn
tmp_tensor = tf.nn.relu(dense_out_1)
顺序采用input => dense => bn => relu => add => relu,参考这个:
同样是运行[3, 10, 20, 35, 50, 100]层,训练集的accuracy随iter的变化如下:
从训练集来看,35层和100层明显低于其他层但是差距并不大,50层和3,10,20没有明显差异,不同层下整体训练都能收敛,再看测试集:
3,10,20三者没有明显差异,35,50,100随着层数越来越大测试效果逐渐变差,但是也能平均保持在0.7的准确率,相比普通网络只有0.5出头已经有很大改观,从而验证了在深层网络加入残差连接的有效性。