深度学习中的Batch Normalization技术详解
深度学习中的Batch Normalization技术详解
如是我闻:Batch Normalization(批归一化,简称 BN)是2015 年由 Ioffe 和 Szegedy 提出的一种加速深度神经网络训练并提高稳定性的技术。它的核心思想是:在每一层的输入进行归一化,使其均值接近 0,方差接近 1,从而减少不同批次数据的分布变化(Internal Covariate Shift),提高训练效率,并降低对超参数的敏感性。
1. 为什么需要 Batch Normalization?
(1) 训练过程中数据分布会变化
- 在深度神经网络中,每一层的输入数据并不是固定的,而是来自前一层的输出。
- 随着训练进行,前几层的权重不断变化,导致后面层的输入数据分布发生变化(即Internal Covariate Shift)。
- 这种变化会让网络不断适应新的数据分布,影响收敛速度,甚至可能导致梯度消失或梯度爆炸问题。
(2) 归一化输入可以加速收敛
- 在训练神经网络时,通常对输入数据进行归一化(标准化),即让输入数据的均值为 0,方差为 1:
x' = \frac{x - \mu}{\sigma} - 但是,如果只对输入数据归一化,而不对隐藏层的输入归一化,那么后续层仍然可能受到数据分布变化的影响。
(3) Batch Normalization 解决了什么问题?
- 减少 Internal Covariate Shift,让每层的输入分布更加稳定。
- 加速收敛,使网络能够使用更大学习率进行训练。
- 减少梯度消失和梯度爆炸问题,提高深度网络的训练稳定性。
- 减少对超参数(如学习率、权重初始化)的依赖,使得网络更容易调参。
- 有一定的正则化效果,降低过拟合的风险。
2. Batch Normalization 的计算过程
假设当前网络有一层的输入是 x,Batch Normalization 计算过程如下:
(1) 计算均值和方差
对一个 batch 的数据 B = {x_1, x_2, ..., x_m},计算该 batch 的均值 \mu_B 和方差 \sigma_B^2:
\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i
\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2
(2) 归一化数据
用均值和标准差对数据进行标准化:
\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
其中 \epsilon 是一个很小的数,防止除以 0。
(3) 线性变换(可学习参数)
为了保证 BN 不会限制网络的表达能力,我们引入两个可学习参数:
- 缩放参数 \gamma(scale):控制归一化后的分布的尺度。
- 平移参数 \beta(shift):让归一化后的数据能够恢复到合适的分布。
最终输出:
y_i = \gamma \hat{x}_i + \beta
这样,BN 既能保证数据的稳定性,又能让网络学到适当的分布。
3. Batch Normalization 在网络中的作用
BN 层通常可以添加到全连接层或卷积层之后,ReLU 之前:
- 在全连接网络(MLP)中:
z = Wx + b 之后,加入 Batch Normalization:
\hat{z} = \frac{z - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
然后乘以 \gamma 并加上 \beta:
y = \gamma \hat{z} + \beta
最后再经过激活函数(如 ReLU)。
- 在 CNN 里,对每个通道的特征图进行归一化:
\mu_B = \frac{1}{m \cdot h \cdot w} \sum_{i=1}^{m} \sum_{j=1}^{h} \sum_{k=1}^{w} x_{i, j, k}
其中 m 是 batch 大小,h, w 是特征图的高度和宽度。
4. Batch Normalization 的优缺点
✅ 优点
- 加速训练(可以使用更大学习率)。
- 减少梯度消失/梯度爆炸问题。
- 提高网络的泛化能力,有一定的正则化效果(但不完全等同于 Dropout)。
- 降低对权重初始化的敏感性。
❌ 缺点
- 对小 batch 不友好(因为均值和方差计算会不稳定)。
- 在 RNN 里效果不好(时间序列数据的统计特性不同)。
- 推理时计算均值和方差会增加计算量。
6. BN 和其他归一化方法的对比
归一化方法 | 应用场景 | 归一化维度 | 适用于 RNN? |
---|---|---|---|
Batch Normalization | CNN, MLP | 在 batch 维度计算均值和方差 | ❌ |
Layer Normalization | RNN, Transformer | 在特征维度归一化(不依赖 batch) | ✅ |
Instance Normalization | 风格迁移 | 在每个样本的特征图上归一化 | ❌ |
Group Normalization | 小 batch CNN | 在多个通道分组归一化 | ✅ |
7. 总的来说
- Batch Normalization(BN)是深度学习中的一个重要归一化技术,它的目标是减少 Internal Covariate Shift,提高训练速度和稳定性。
- 核心步骤:
- 计算 batch 均值和方差。
- 归一化数据,使其均值 0,方差 1。
- 使用可学习参数 \gamma 和 \beta 进行缩放和平移。
- BN 主要作用:
- 加速收敛,可以使用更大学习率。
- 减少梯度消失/梯度爆炸问题,提高稳定性。
- 有一定的正则化作用,降低过拟合。
- 缺点:
- 小 batch 训练效果较差。
- 在 RNN 里效果不好。