问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

大模型变“健忘症患者“?——公式揭秘灾难性遗忘的防治之道

创作时间:
作者:
@小白创作中心

大模型变“健忘症患者“?——公式揭秘灾难性遗忘的防治之道

引用
CSDN
1.
https://m.blog.csdn.net/qq_37148940/article/details/145791053

大模型在学习新任务时往往会遗忘旧任务的知识,这种现象被称为“灾难性遗忘”。本文通过类比和公式推导的方式,详细解释了弹性权重巩固(EWC)、知识蒸馏、梯度投影和回放策略等解决方案,并提供了EWC方法的代码实现。

核心结论:模型记忆就像橡皮泥,捏新形状时旧痕迹会消失

“给橡皮泥加铁丝骨架(正则化),或定期回放旧形状(回放策略),就能保持新旧记忆共存”

公式推演与类比解释

1. 核心公式对比表

公式名称
数学表达式
通俗解释
类比场景
弹性权重巩固
$L_{\text{total}} = L_{\text{new}} + \lambda \sum_i F_i (\theta_i - \theta_{\text{old},i})^2$
重要神经连接加防护罩
大脑重要记忆区设置保护屏障
知识蒸馏
$L_{\text{KD}} = \alpha T^2 \text{KL}(q^{\text{old}}
q^{\text{new}}) + (1-\alpha)L_{\text{task}}$
梯度投影
$\min
g - g_{\text{new}}
^2 \quad \text{s.t.} \quad g^T g_{\text{old}} \geq 0$
回放策略
$L = \mathbb{E}{(x,y)\sim D{\text{new}}} [l(f(x),y)] + \beta \mathbb{E}{(x,y)\sim D{\text{old}}} [l(f(x),y)]$
新旧知识交替复习
背单词时定期复习前几课内容

2. 核心公式详解

公式1:弹性权重巩固(EWC)

$$
L = \underbrace{L_{\text{new}}(\theta)}{\text{新任务损失}} + \underbrace{\frac{\lambda}{2} \sum_i F_i (\theta_i - \theta{\text{old},i})^2}_{\text{记忆保护项}}
$$

参数
数学符号
类比解释
作用机制
重要性矩阵
$F_i$
神经连接的重要程度评分
Fisher信息矩阵对角元素
旧参数
$\theta_{\text{old}}$
橡皮泥的原始形状
预训练模型参数
惩罚系数
$\lambda$
新旧记忆的平衡调节器
超参数调节新旧任务权重

案例应用:在BERT微调中,EWC保护语言理解相关的关键参数不被过度修改

公式2:梯度投影约束

$$
\begin{cases}
\min |g - g_{\text{new}}|^2 \
g^T g_{\text{old}} \geq 0
\end{cases}
$$

物理意义
类比解释
$g_{\text{new}}$
新任务梯度方向
想学习的新技能方向
$g_{\text{old}}$
旧任务梯度方向
已掌握的技能方向
$g^T g_{\text{old}}$
方向一致性检测
确保不背道而驰的指南针

3. 进阶公式推导

动态网络扩展

$$
\theta = \theta_{\text{base}} \oplus \theta_{\text{new}}
$$

$$
\mathcal{M}(x) = f_{\theta_{\text{base}}}(x) + g_{\theta_{\text{new}}}(x)
$$

参数隔离策略

$$
\mathbb{P}(w_i \text{被冻结}) = \frac{e^{-\gamma F_i}}{\sum_j e^{-\gamma F_j}}
$$

在线回放优化

$$
L_{\text{replay}} = \frac{1}{N}\sum_{i=1}^N \mathbb{E}{x\sim M_i} [\text{KL}(f{\theta}(x)||f_{\theta_i}(x))]
$$

代码实战:EWC方法实现

import torch
import numpy as np
from torch.nn.functional import kl_div

class EWC_Optimizer:
    def __init__(self, model, lambda_=0.5):
        self.model = model
        self.lambda_ = lambda_
        self.fisher_matrix = {}
        self.old_params = {}
        
    def calculate_fisher(self, dataset):
        # 计算Fisher信息矩阵
        fisher = {}
        for name, param in self.model.named_parameters():
            fisher[name] = torch.zeros_like(param)
            
        for data, _ in dataset:
            self.model.zero_grad()
            output = self.model(data)
            loss = torch.nn.functional.nll_loss(output, target)
            loss.backward()
            
            for name, param in self.model.named_parameters():
                fisher[name] += param.grad.pow(2) / len(dataset)
                
        self.fisher_matrix = fisher
        self.old_params = {n:p.clone() for n,p in model.named_parameters()}
        
    def penalty_loss(self):
        loss = 0
        for name, param in self.model.named_parameters():
            old_param = self.old_params[name]
            fisher = self.fisher_matrix[name]
            loss += (fisher * (param - old_param).pow(2)).sum()
        return self.lambda_ * loss

# 训练示例
ewc = EWC_Optimizer(model, lambda_=0.8)
ewc.calculate_fisher(old_task_loader)  # 预训练任务数据
for epoch in range(100):
    for data, target in new_task_loader:
        optimizer.zero_grad()
        output = model(data)
        task_loss = F.cross_entropy(output, target)
        ewc_loss = ewc.penalty_loss()
        total_loss = task_loss + ewc_loss
        total_loss.backward()
        optimizer.step()

# 可视化Fisher矩阵
plt.figure(figsize=(10,6))
plt.imshow(np.log(ewc.fisher_matrix['layer1.weight'].cpu().numpy()+1e-9), 
           cmap='viridis')
plt.colorbar()
plt.title("Fisher Information Heatmap")
plt.show()

可视化解析

  1. Fisher矩阵热力图:显示神经网络各层参数重要性分布,颜色越亮表示对旧任务越关键
  2. 损失函数曲线:蓝色曲线显示新任务损失下降,红色曲线显示旧任务损失波动范围
  3. 参数分布雷达图:对比显示EWC约束下参数偏移量明显小于普通微调

公式体系总览

公式类型
典型代表
防御机制
正则化公式
EWC约束项
参数空间锚定
动态架构公式
Progressive Neural Networks
网络结构扩展
记忆回放公式
iCaRL选择策略
样本重放机制
元学习公式
MAML优化器
快速适应机制
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号