问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

Llama改进之——均方根层归一化RMSNorm

创作时间:
作者:
@小白创作中心

Llama改进之——均方根层归一化RMSNorm

引用
CSDN
1.
https://blog.csdn.net/yjw123456/article/details/138139970

Llama模型在AI领域引起了广泛关注,其在模型架构和优化方面的创新为大语言模型的发展提供了新的思路。本文将介绍Llama模型中的一项重要改进——均方根层归一化(RMSNorm),并详细阐述其原理、优势以及具体实现方法。

LayerNorm

层归一化(LayerNorm)对Transformer等模型来说非常重要,它可以帮助稳定训练并提升模型收敛性。LayerNorm针对一个样本所有特征计算均值和方差,然后使用这些来对样本进行归一化:

$$
\mu = \frac{1}{H}\sum_{i=1}^H x_i, \quad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^H (x_i - \mu)^2}, \quad N(\mathbf{x}) = \frac{\mathbf{x}-\mu}{\sigma}, \quad \mathbf{h} = \mathbf{g} \odot N(\mathbf{x}) + \mathbf{b} \tag{1}
$$

这里 $\mathbf{x} = (x_1, x_2, \cdots, x_H)$ 表示某个时间步LN层的输入向量表示,向量维度为 $H$;$\mathbf{h}$ 实LN层的输出;$\mathbf{g}, \mathbf{b}$ 是两个可学习的参数。

为什么层归一化有用?一些解释如下:

  1. 减少内部协变量偏移(Internal Covariate Shift): 内部协变量偏移是指在深度神经网络的训练过程中,每一层输入的分布会发生变化,导致网络的训练变得困难。层归一化通过对每一层的输入进行归一化处理,可以减少内部协变量偏移,使得每一层的输入分布更加稳定。
  2. 稳定化梯度: 层归一化有助于保持每一层输出的均值和方差稳定,从而使得梯度的传播更加稳定。这有助于减少梯度消失或梯度爆炸的问题,提高梯度在网络中的流动性,加快训练速度。
  3. 更好的参数初始化和学习率调整: 通过层归一化,每一层的输入分布被归一化到均值为0、方差为1的标准正态分布,这有助于更好地初始化网络参数和调整学习率。参数初始化与学习率调整的稳定性对模型的训练效果至关重要。
  4. 增强模型的泛化能力: 层归一化可以减少网络对训练数据分布的依赖,降低了过拟合的风险,从而提高模型的泛化能力。稳定的输入分布有助于模型更好地适应不同数据集和任务。

RMSNorm

虽然LayerNorm很好,但是它每次需要计算均值和方差。RMSNorm的思想就是移除(1)式中 $\mu$ 的计算部分:

$$
\bar{x}i = \frac{x_i}{\text{RMS}(\mathbf{x})} g_i \quad \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{H} \sum{i=1}^H x_i^2} \tag{2}
$$

同时在实现也可以移除平移偏置 $\mathbf{b}$。

单看(2)式的话,相当于仅使用 $\mathbf{x}$ 的均方根来对输入进行归一化,它简化了层归一化的计算,变得更加高效,同时还有可能带来性能上的提升。

实现

RMSNorm的实现很简单:

import torch
import torch.nn as nn
from torch import Tensor

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
    
    def _norm(self, hidden_states: Tensor) -> Tensor:
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return hidden_states * torch.rsqrt(variance + self.eps)
    
    def forward(self, hidden_states: Tensor) -> Tensor:
        return self.weight * self._norm(hidden_states.float()).type_as(hidden_states)

torch.rsqrttorch.sqrt的倒数;eps是一个很小的数,防止除零;hidden_states.float()确保了标准差计算的精确度和稳定性,然后在forward方法中,通过.type_as(hidden_states)将结果转换回原来的数据类型,以保持与输入张量相同的数据类型,使得归一化处理后的结果与输入数据类型一致。

下面通过一个简单的网络来测试一下:

import torch
import torch.nn as nn
from torch import Tensor

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = nn.Linear(in_features=10, out_features=5)
        self.rmsnorm = RMSNorm(hidden_size=5)

    def forward(self, x):
        x = self.linear(x)
        x = self.rmsnorm(x)
        return x

net = SimpleNet()

input_data = torch.randn(2, 10)  # 2个样本,每个样本包含10个特征

output = net(input_data)

print("Input Shape:", input_data.shape)
print("Output Shape:", output.shape)

输出结果:

Input Shape: torch.Size([2, 10])
Output Shape: torch.Size([2, 5])

参考

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号