深度学习中神奇的lr schedule算法—余弦退火
深度学习中神奇的lr schedule算法—余弦退火
深度学习中的优化函数选择是建立模型过程中非常重要的环节,其中学习率的选择更是超参数调节中的关键部分。随着模型规模的不断扩大,固定的学习率已经无法满足需求,因此学习率调度(lr scheduler)算法应运而生。本文将重点介绍其中一种非常典型且广泛使用的算法——余弦退火算法,探讨其在模型训练中的重要作用。
1. lr schedule的诞生
深度学习的概念最早在2006年被提出,其核心思想是通过寻找适当的模型、拟合函数和损失函数,利用优化算法来更新模型中不同数值的权重,从而达到训练的效果。在这个过程中,学习率(learning rate)是一个非常重要的参数。优化算法的基本思想是根据损失函数在当前点的梯度方向和幅度来逼近局部最优解(鞍点),而学习率就是控制自变量在该点随着梯度变化而移动的幅度的参数。
但是,将学习率设定为一个固定的超参数(如最经典的随机梯度下降算法SGD)会带来一些问题。学习率过小可能会导致优化速度过慢,损失函数无法收敛;学习率过大则可能会造成数值爆炸或者使函数在最优解附近反复震荡无法收敛。下面的代码直观地展示了这个问题:
import matplotlib.pyplot as plt
import numpy as np
# 初始函数记为 y=x**2
# 这样的函数可以非常清晰地表现出超参数学习率带来的一些后果
def get_learning_rate(lr, x):
epoch = 10
results = [x]
for i in range(epoch):
x -= lr * 2 * x
# x**2 的导数为 2*x,这里模仿随机梯度下降函数 sgd 对 x 进行迭代,观察 x 的变化
results.append(x)
return results
def show_x(results):
n = max(abs(min(results)), abs(max(results)))
f_line = np.arange(-n, n, 0.1)
plt.plot(f_line, [x**2 for x in f_line])
plt.plot(results, [x**2 for x in results], '-o')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.show()
x = 10
lr = float(input("请输入学习率: "))
show_x(get_learning_rate(lr, x))
通过观察不同学习率下 x 通过 lr 权重的梯度下降后的散点图,我们可以发现,选定合适的超参数 lr 确实能达到局部最优解,但是往往也会出现数值爆炸或是局部震荡的情况。设定超参数学习率在面对大型模型的时候往往无能为力,难以达到最优解。
lr schedule 算法应运而生,lr schedule 算法的基本原理就是通过权重衰减改变学习率,从而实现在一开始可以设定较大的学习率快速逼近局部最优解,随后通过减小学习率使损失函数逐渐收敛于局部最优解,这种算法同时保证了训练的效率和训练的准确率,逐渐被广泛应用于大型模型的训练。
2. 应用 lr schedule 思想的算法
2.1 AdaGrad 算法
AdaGrad 算法的核心思想就是对学习率采用一个小批量随机梯度 gt 按照元素平方的累加(在大规模模型的训练中,通常采用矩阵运算来加快学习)得到权重衰减的变量st,在 t=0 时,初始化s0 中所有元素为 0:
接着对 xt 进行迭代,其中ε是为了保证数值稳定的常数:
AdaGrad 算法通过中间量st的简单迭代实现了学习率权重衰减。
2.2 RMSProp 算法
RMSProp 算法与 AdaGrad 算法最根本的区别在于第一步,RMSProp 算法对st迭代的时候设置了权重超参数α,通过对权重的设定我们可以控制学习率衰减的程度,具体实现如下:
该算法解决了 AdaGrad 训练后期依然有可能出现的学习率过小而导致无法逼近局部最优解的情况。
2.3 AdaDelta 算法
AdaDelta 算法与 RMSProp算法的不同就在于前者使用了一个新的状态变量Δxt,该状态变量在 t=0 时被初始化为 0,同时将gt,xt和Δxt的迭代作了如下调整:
该算法与之前算法的核心区别在于使用了状态变量Δxt来替代了原来的超参数η。
2.4 总结
现在流行的优化算法中基本上都包括了 lr schedule 的思想,例如 Adam 算法等,lr schedule 已经成为了深度学习中最主流的优化思想,兼顾了学习效率与模型准确率。
3. 余弦退火
3.1 余弦退火算法
3.1.1 上述算法面对多峰函数可能出现的问题
尽管上述算法都已经运用了一些 lr schedule 的思想进行了学习率的衰减,但是依然可能出现一些问题。有时候模型的损失函数可能是一个多峰函数,而上述算法包括更为强大的 Adam 算法对学习率权重的衰减都是递减的,通过上述算法,损失在达到局部最小值之后由于过小的学习率导致损失函数会收敛于局部最小值附近,但是模型训练需要达到全局最小值,这些递减的权重算法无法跳出局部最小值。
图 2-1 为一个简单的多峰函数,其中红色点为该函数的全局最小值,绿色点为局部最小值,如果初始化模型后起点为紫色的 start point,经过上述算法(AdaGrad,RMSProp,Adadelta 等)训练后,损失函数会收敛于绿色点附近,但是全局最优解在红色点,因此上述算法并不能应对此类情况。在这种情况下,一种更优化的算法——余弦退火算法应运而生。
3.1.2 余弦退火算法的基本数学原理
余弦退火算法属于一种 SGDR 算法,其中 R 即 restart,其特点就是经过 epoch_ix 次迭代后学习率会重新调整至一个较大的值,随后重新进行学习率的权重衰减。因此这一类算法具有类似于周期性的性质。
余弦退火算法选择利用余弦函数进行权重衰减的原因在于余弦函数本身的特点。观察余弦函数 f(x)=cosx 的导函数 f’(x)=-sinx,余弦函数在二分之一周期[2kπ, 2kπ + π/2]内,余弦函数导函数的绝对值先增大后减小,对应余弦函数的斜率先较水平,随后迅速降低,最后缓慢回到水平,这样的变化正好满足深度学习对学习率权重衰减的要求。学习率先保持一个较大的值缓慢下降,保证损失函数可以快速接近局部最优解,在接近最优解之后学习率权重迅速减小,开始缓慢接近最优解,最后保持一个较小的值使损失函数收敛于最优解附近。同时,余弦函数本身是一个周期函数,因此可以用单个数学表达式去表现出每一段区间上学习率的变化,降低了算法的复杂性,节约了更多的内存给训练。
根据余弦退火算法的思想,可以对这个算法给定一些参数。首先需要设定学习率的最大值ηmax和最小值ηmin,在经过 epoch_ix 次计算后,学习率达到全局最小值ηmin,在epoch_ix+1 次时,学习率重新回到全局最大值ηmax,进入下一个 epoch 周期重新进行权重衰减,这个步骤称之为重启(restart),每一次重启后,对参数 Ti 的下标进行加一,表示该函数进行了 i 次重启。重启后学习率重新回到一个较大的值,但是迭代的参数不会重启而依然保持在旧的 xt,从而可以跳出局部最优解(如果当前的 xt 已经是全局最小值了,那么在多次重启后依然会回到 xt)。Tcur 即指当前进行的学习次数。在 2017 年时,Ilya Loshchilov & Frank Hutter 在论文《SGDR: STOCHASTIC GRADIENT DESCENT WITH WARM RESTARTS》中给出了一种沿用至今的余弦退火数学表达式(该表达式只展示了一个周期内的权重衰减,没有包含重启):
图 2-2 一个简单的余弦退火函数中学习率的变化
3.2 余弦退火算法的具体实现
3.2.1 step_scale
在实际的训练中,在一定次数的训练后,损失函数可能已经达到全局最优解附近,每次重启后的学习率不必返回到一开始设定的全局最大值(否则会浪费大量时间在跳出最优解和回到最优解上,也会造成难以收敛的情况),因此在每次重启后,设定的返回值要逐渐减小,因此设定 step_scale 参数,可以使函数每次重启的时候乘上一个 step_scale 的权重,实现每次重启后学习率的最大值逐渐减小的效果,如下图所示:
图 2-3 设定 step_scale 后学习率的变化
图 2-4 重启时未达到最小值的情况
3.2.2 Tmult 参数
同样的,在设定重启所间隔的步数时,也考虑到了同样的问题。到了训练后期,随着损失函数逐渐靠近全局最优解,在这时用测试集去测试正确率,总能在学习率较低的时候获得一个较高的正确率,但是在重启后测试,正确率会降低,这是因为重启后较大的学习率会导致损失函数冲出全局最优解而重新开始训练,因此一开始设定的重启周期将不再适用(本质上重启是为了越过局部最小值而设定的,当损失函数已经接近全局最优解时,重启已经没有意义),需要引入新的参量 Tmult,新参量的引入实现了如下效果:
T_0 参数表示函数第一次重启时的 epoch;
当 Tmult 没有被引入时(default= 1),那么学习率将在 T_0, 2T_0, 3T_0, ......, i*T_0 处回到最大值,具体图像参考图 2-1;
例如,当设定 T_0=5 时,学习率会在 5,10,15,......,5i 的位置重启;
引入参量 Tmult,学习率将在 T_0, (1+Tmult)*T_0, (1+Tmult+Tmult*2)*T_0, ......, (1+Tmult+Tmult**2+......+Tmult*i)*T_0 处回到最大值;
例如,当设定 T_0=5,Tmult=2 时候,学习率会在 5,15,35,......处重启。
图 2-5 引入变量 Tmult 后学习率的变化
图 2-5 展示了 Tmult 被引入后学习率的变化图像,每次重启所需要的 epoch 会逐渐变大,这样到了训练后期,学习率不会再有重启的过程,而是一直保持下降的趋势直到训练结束,可以有效地避免训练后期重启后损失函数冲出全局最优解的情况。
3.2.3 warm up
在实战中,大部分模型的参数都是随机初始化的,因此初始参量的数值可能不稳定,训练时如果贸然将学习率重启为一个较大的值可能会发生数值爆炸导致后续训练无法完成,因此需要引入一个新的步骤——warm up。在 warm up 步骤中,学习率会从ηmin逐渐增大到 ηmax,保证了模型的数值不会在重启后爆炸,而是短暂的学习一段时间将数值稳定下来再调大学习率开始寻找最优解。
图 2-6 加入 warm up 机制后的学习率变化
4. 余弦退火的优势与不足——实例展示
4.1 优势一——快速逼近最优解,无需调节超参数学习率
余弦退火一大重要的职能就是自动迭代学习率,从而可以在训练前期快速逼近全局最优解达到较高的效率和稳定性。下面的代码可以展示余弦退火在这方面的优势,其中选择的对比函数为随机梯度下降(SGD),损失函数设定为 y=x**2,由于没有局部最优解的存在,因此我们不需要进行重启,同时为了减少余弦退火本身的算法复杂性带来的时间差,我们将省略 step_scale,Tmult 以及 warm up 的选择和使用。
代码如下:
上述代码的输出结果:
可以看到,面对相同的损失函数,余弦退火算法比设定超参数的 SGD 算法可以以更少的训练时间和训练次数达到全局最优解。
4.2 优势二——可以越过局部最优解,寻找全局最优解
这里我们使用图 2-1 展示的函数进行训练,同样为了减少余弦退火本身的算法复杂性带来的时间差,我们将省略 step_scale,Tmult 以及 warm up 的选择和使用,但这次我们需要设定重启值。
通过下面两张图片的对比我们可以很直观地发现重启带来的作用:
图 3-1 采用 SGD 算法,从 x=5.8 处开始迭代的过程,收敛于局部最小值
图 3-2 采用余弦退火算法,在重启一定次数后冲出局部最小值
图 3-3 采用余弦退火算法,冲出局部最小值后逐渐收敛于全局最小值
4.3 不足之处
余弦退火算法一大不足之处在于需要设定较多的超参数,在训练初期需要反复尝试以寻找合适的ηmax、ηmin、Ti 和 Tmult。依然以图 2-1 的函数为例展示超参数未设置恰当带来的影响。
图 3-4、图 3-5 分别为设定过大的重启周期和过大的ηmax,导致函数跌倒无法收敛
尽管余弦退火算法仍有不足之处,但是其为我们训练大型模型带来的巨大便利性依然是不可忽视的。
参考文献
[1] Ilya Loshchilov, Frank Hutter, University of Freiburg, Freiburg, Germany. SGDR: Stochastic Gradient Descent With Warm Restarts