问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

指数加权移动平均(EMA)在稳定扩散模型中的应用

创作时间:
作者:
@小白创作中心

指数加权移动平均(EMA)在稳定扩散模型中的应用

引用
CSDN
1.
https://blog.csdn.net/weixin_48524215/article/details/140687147

本文介绍了指数加权移动平均(Exponential Moving Average,EMA)在稳定扩散模型(Stable Diffusion)中的应用。文章详细解释了EMA的原理、在稳定扩散模型中的具体应用以及如何在代码中实现EMA。

1. 移动平均在稳定扩散模型中的应用(SMA&EMA)

1.1 简单移动平均(SMA,无权重MA)

1.2 指数加权移动平均(EMA,加权MA)

在稳定扩散模型的训练中,指数加权移动平均(EMA)是一种用于优化机器学习模型(特别是神经网络)的技术。与简单移动平均不同,EMA对近期数据点赋予更多权重,使其对近期变化的响应更灵敏。

1.2.1 EMA在稳定扩散模型中的应用

在稳定扩散模型中,EMA应用于模型参数的训练过程,以创建一个平滑的模型版本。这在机器学习中特别有用,因为训练过程可能很嘈杂,模型参数在收敛到最优解时可能会振荡。通过维护模型参数的EMA,训练过程可以受益于以下方面:

  1. 平滑:EMA平滑参数更新,减少噪声影响,使训练过程更稳定。
  2. 更好泛化:EMA版本的模型通常在未见过的数据上表现更好,因为EMA倾向于选择随时间保持一致的参数值。
  3. 防止过拟合:通过时间平均参数,EMA可以帮助缓解过拟合,特别是在模型可能过快收敛到次优解的情况下。

在训练SD时的MSE Loss在梯度下降过程中是上下震荡的,对应的模型参数也在震荡,可以用EMA取得这些模型参数震荡值的中间值,这个模型参数的中间值也就能更好的代表所有时刻模型参数的平均水平,让模型获得了更好的泛化能力。

稳定扩散2使用指数加权平均(EMA),它维护权重的指数加权平均。在每个时间步,EMA模型通过取当前EMA模型的0.9999倍加上最新前向和后向传递后的权重的0.0001倍来更新。默认情况下,整个训练期间都会应用EMA算法。然而,由于需要在每一步读写所有权重,这可能会很慢。

每个时间步都对所有参数进行EMA代价较大,因为要在每个时刻读写模型的全部参数

为了降低计算EMA的代价,我们仅在最后时间段进行EMA计算。具体来说,我们训练1,400,000个批次,并仅在最后50,000个步骤应用EMA,这大约占训练期的3.5%。前1,350,000个迭代的权重会以0.9999的因子衰减,因此它们在最终模型中的总贡献权重不到1%。使用这种技术,我们可以避免为96.5%的训练添加开销,同时仍然获得几乎等效的EMA模型。

1.2.2 在稳定扩散模型中的实现

在扩散模型的训练过程中,模型权重的EMA会与常规更新一起更新。这是一个典型的过程:

  1. 初始化EMA权重:在训练开始时,将EMA权重初始化为与模型初始权重相同。
  2. 在训练期间更新:在每次批次更新后,使用上述公式更新EMA权重。这需要为EMA存储一组单独的权重。
  3. 用于推理:在训练结束时,使用EMA权重而不是原始模型权重进行推理。这是因为EMA权重代表了一个更稳定且可能表现更好的模型版本。

1.2.3 实用考虑

  1. 选择α:平滑因子α是一个需要仔细选择的超参数。常见的做法是基于迭代次数或周期设置α,例如α = 2 / (N + 1),其中N是迭代次数。
  2. 性能开销:维护EMA权重需要额外的内存和计算开销,但模型稳定性和性能的提升通常可以抵消这些成本。

EMA类实现

class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # EMA的平滑因子
        self.step = 0  # 步数计数器

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        ema_model.load_state_dict(model.state_dict())

训练代码实现

def train(args):
    device = args.device
    model = UNET().to(device)
    model.train()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)
    mse = nn.MSELoss()
    logger = SummaryWriter(os.path.join("runs", args.run_name))
    len_train = len(train_loader)

    ema = EMA(0.995)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    print('Start into the loop !')
    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        progress_bar = tqdm(train_loader)
        optimizer.zero_grad()

        accumulation_steps = 4
        for batch_idx, (images, captions) in enumerate(progress_bar):
            images = images.to(device)
            images = torch.squeeze(images, dim=1)
            captions = captions.to(device)
            text_embeddings = torch.squeeze(captions, dim=1)
            timesteps = ddpm_sampler.sample_timesteps(images.shape[0]).to(device)
            noisy_latent_images, noises = ddpm_sampler.add_noise(images, timesteps)
            time_embeddings = timesteps_to_time_emb(timesteps)

            with torch.no_grad():
                last_decoder_noise = model(noisy_latent_images, text_embeddings, time_embeddings)
            final_output = diffusion.final.to(device)
            predicted_noise = final_output(last_decoder_noise).to(device)
            loss = mse(noises, predicted_noise)
            loss.backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            ema.step_ema(ema_model, model)
            progress_bar.set_postfix(MSE=loss.item())
            logger.add_scalar("MSE", loss.item(), global_step=epoch * len_train + batch_idx)

        os.makedirs(os.path.join("models", args.run_name), exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", args.run_name, f"stable_diffusion.ckpt"))
        torch.save(optimizer.state_dict(), os.path.join("models", args.run_name, f"optim.pt"))
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号