问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Adam优化器理解和代码实现

创作时间:
作者:
@小白创作中心

Adam优化器理解和代码实现

引用
CSDN
1.
https://blog.csdn.net/weixin_48435461/article/details/143946446

Adam优化器理解和代码实现

Adam优化器的公式理解和代码实现

一、公式理解

1. 梯度计算

首先,根据损失函数$f(\theta)$对参数$\theta$计算梯度:

$$
g_t = \nabla_{\theta} f_t(\theta_{t-1})
$$

作用:获取当前参数$\theta$下的梯度$g_t$,表示在当前点向何处移动可以减少损失。

2. 一阶矩的更新(动量)

更新一阶矩估计$m_t$:

$$
m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t
$$

含义

  • $m_t$:梯度的指数加权移动平均,类似于动量的概念。
  • $\beta_1$:控制过去梯度对当前梯度的影响,常取$\beta_1 = 0.9$。
  • $g_t$:当前的梯度。
  • $\beta_1 \cdot m_{t-1}$:表示保留之前动量的部分。
  • $(1 - \beta_1) \cdot g_t$:表示当前梯度对动量的贡献。

3. 二阶矩的更新(未中心化方差)

更新二阶矩估计$v_t$:

$$
v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2
$$

含义

  • $v_t$:梯度平方的指数加权移动平均,用于估计梯度的方差(幅度的大小)。
  • $\beta_2$:控制过去梯度平方对当前梯度平方的影响,常取$\beta_2 = 0.999$。
  • $g_t^2$:当前梯度平方,反映梯度的幅度大小。

目标:为每个参数提供自适应的学习率调整,幅度大的梯度会被减弱,幅度小的梯度会被放大。

4. 偏差校正

由于一阶矩和二阶矩在初始时刻$t$较小时存在偏差,需要进行校正。

一阶矩校正:

$$
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}
$$

含义:将$m_t$修正为无偏估计,减小由于初始化为 0 带来的偏差。

二阶矩校正:

$$
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
$$

含义:将$v_t$修正为无偏估计,同样消除初始值为 0 时的偏差。

5. 参数更新

使用校正后的矩估计更新参数:

$$
\theta_t = \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
$$

含义

  • $\alpha$:学习率。
  • $\hat{m}_t$:校正后的一阶矩,表示方向。
  • $\sqrt{\hat{v}_t} + \epsilon$:校正后的二阶矩,用于调整学习率并避免分母为 0。

目标:自适应地调整步长,使得学习率根据梯度的统计特性动态变化。

6. 时间步更新

在每一轮更新后,时间步$t$加 1:

$$
t = t + 1
$$

含义:记录当前是第几次迭代,便于计算偏差校正的因子$\beta_1^t$和$\beta_2^t$。

总结

  • 一阶矩$m_t$:累积了梯度的动量信息,主要影响更新的方向。
  • 二阶矩$v_t$:追踪梯度的幅度信息,控制每个参数的学习率大小。
  • 偏差校正:消除了初始值为零带来的偏差。
  • 参数更新公式:自适应地调整了学习率,使得学习率根据梯度的统计特性动态变化。

二、代码实现

以$f(x) = x^3 + 3x^2 - 2x$为例

Adam优化与SGD(随机梯度优化)比较

import matplotlib.pyplot as plt

# 函数定义
def f(x):
    return x**3 + 3*x**2 - 2*x

# 导数定义
def gred_f(x):
    return 3*x**2 + 6*x - 2

def adam_optimizer(gred_f, x0, alpha=0.1, beta1=0.9, beta2=0.999, epsilon=1e-8, max_iter=100):
    x = x0
    m = 0 # 初始化动量
    v = 0 # 初始化方差
    t = 0
    x_values = [] # 存储更新后的x
    f_values = [] # 存储更新后f(x)值
    while t < max_iter:
        t += 1
        g = gred_f(x) # 计算梯度
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1-beta1**t)
        v_hat = v / (1-beta2**t)
        x = x - alpha * m_hat / (v_hat**0.5 + epsilon)
        x_values.append(x)
        f_values.append(f(x))
    return x_values, f_values

def sgd_optimizer(gred_f, x0, alpha=0.1, max_iter=100):
    x = x0
    x_values = []
    f_values = []
    while len(x_values) < max_iter:
        g = gred_f(x)
        x = x - alpha * g
        x_values.append(x)
        f_values.append(f(x))
    return x_values, f_values

绘图比对

x0 = -1

# 计算更新值
adam_x_values, adam_f_values = adam_optimizer(gred_f, x0)
sgd_x_values, sgd_f_values = sgd_optimizer(gred_f, x0)

# 绘制图片
plt.figure(figsize=(10, 5))
plt.plot(adam_f_values, label='Adam Optimizer', linestyle='-', linewidth=2)
plt.plot(sgd_f_values, label='SGD Optimizer', linestyle='--', linewidth=2)
plt.xlabel('Iterations')
plt.ylabel('f(x)')
plt.title('Comparison of Adam and SGD on $f(x) = x^3 + 3x^2 - 2x$')
plt.legend()
plt.grid(True)
plt.show()

结果分析

分析对比图可以看到,Adam优化器其实更具备逃离鞍点的能力,因为是波浪式的,所以不容易陷入局部最优;

需要注意的是,代码实现只是绘制了一个一元函数,但实际我们面对的是多元的神经网络,这时的SGD相较于Adam的搜索能力和收敛速度就显得比较乏力了

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号