问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

交替优化ADMM:受限问题、对抗网络和鲁棒模型研究(Matlab代码实现)

创作时间:
作者:
@小白创作中心

交替优化ADMM:受限问题、对抗网络和鲁棒模型研究(Matlab代码实现)

引用
CSDN
1.
https://m.blog.csdn.net/m0_58086403/article/details/144707791

交替方向乘子法(ADMM)是一种用于求解具有线性约束的优化问题的算法,在机器学习和信号处理等领域有着广泛的应用。本文将介绍ADMM在受限问题、对抗网络和鲁棒模型研究中的应用,并提供具体的Matlab代码实现。

1. 概述

数据驱动的机器学习方法在许多工业应用和学术任务中取得了令人印象深刻的性能。机器学习方法通常分为两个阶段:从大规模样本训练模型,以及在部署模型后对新样本进行推断。现代模型的训练依赖于解决涉及非凸、不可微目标函数和约束的困难优化问题,有时速度较慢,通常需要专业知识来调整超参数。虽然推断比训练快得多,但通常对实时应用来说速度还不够快。我们关注可以在训练中被表述为极小极大问题的机器学习问题,并研究交替优化方法作为快速、可扩展、稳定和自动化的求解器。

首先,我们关注经典凸和非凸优化中约束问题的交替方向乘子法(ADMM)。一些流行的机器学习应用包括稀疏和低秩模型、正则化线性模型、总变差图像处理、半定规划和共识分布式计算。我们提出了自适应ADMM(AADMM),这是一个完全自动化的求解器,通过调整ADMM中唯一的自由参数实现快速实际收敛。我们进一步自动化了几个ADMM的变体(放松的ADMM、多块ADMM和共识ADMM),并证明了适用于具有不同参数的ADMM变体的收敛速率保证。我们发布了超过十种应用的快速实现,并使用每个应用的几个基准数据集验证了效率。其次,我们关注生成对抗网络(GAN)的极小极大问题。我们应用预测步骤来稳定随机交替方法,用于训练GAN,并展示了基于GAN的损失对图像处理任务的优势。我们还提出了基于GAN的知识蒸馏方法,用于训练小型神经网络以加速推断,并在经验上研究了加速和准确性之间的权衡。第三,我们展示了对鲁棒模型的对抗训练的初步结果。我们研究了用于攻击和防御通用扰动的快速算法,然后探讨了提高鲁棒性的网络架构。

交替方向法是解决线性约束下的单调可分变分不等式问题的吸引人方法之一。在应用经验中发现,迭代次数与线性约束方程组的惩罚参数密切相关。在原始方法中,惩罚参数是一个常数,但在本文中,我们提出了一种修改后的交替方向法,根据迭代信息每次调整惩罚参数。初步的数值测试表明,自适应调整技术在实践中是有效的。

包含:

  • 自适应松弛(ARADMM)
  • 自适应共识ADMM(ACADMM)
  • 视觉子类别低秩最小二乘
  • 自适应ADMM(AADMM)
  • 非凸问题的AADMM
  • 自适应多块ADMM的代码包含在此软件包中

我们还提供了基准方法的实现,如基本ADMM、快速(Nestrov)ADMM、残差平衡和归一化残差平衡。

具体应用包括:

  • 带有弹性网(l2 + l1)正则化器的线性回归
  • 带有稀疏(l1/l0)正则化器的线性回归
  • 带有(l1/l2)正则化器的逻辑回归
  • 基础追踪
  • 低秩最小二乘
  • 鲁棒PCA(RPCA)
  • 二次规划(QP)
  • 半定规划(SDP)
  • 支持向量机(SVM)
  • 1D/2D去噪与总变差正则化器
  • 图像去噪/恢复/去模糊与总变差正则化器
  • 分布式共识问题:逻辑回归
  • 分布式共识问题:线性回归
  • 典型的非凸问题:特征值问题
  • 典型的非凸问题:相位恢复

2. 运行结果

部分代码:

%% 
% ADMM 
opts.adp_flag = 0; %fix tau, no adaptation 
[sol1,outs1] = aadmm_lrls(D, c, np, lam1, lam2, opts); 
fprintf('vanilla ADMM complete after %d iterations!\n', outs1.iter); 
% adaptive ADMM 
opts.adp_flag = 5; %AADMM with spectral penalty 
[sol2,outs2] = aadmm_lrls(D, c, np, lam1, lam2, opts); 
fprintf('adaptive ADMM complete after %d iterations!\n', outs2.iter); 
% Nesterov ADMM 
opts.adp_flag = 2; % Nesterove ADMM 
[sol3,outs3] = aadmm_lrls(D, c, np, lam1, lam2, opts); 
fprintf('Nesterove ADMM complete after %d iterations!\n', outs3.iter); 
% adaptive ADMM baseline: residual balance 
opts.adp_flag = 3; %residual balance 
[sol4,outs4] = aadmm_lrls(D, c, np, lam1, lam2, opts); 
fprintf('RB ADMM complete after %d iterations!\n', outs4.iter); 
% adaptive ADMM baseline: normalized residual balance 
opts.adp_flag = 4; %normalized residual balance 
[sol5, outs5] = aadmm_lrls(D, c, np, lam1, lam2, opts); 
fprintf('NRB ADMM complete after %d iterations!\n', outs5.iter); 
%% 
legends = {'Vanilla ADMM', 'Fast ADMM', 'Residual balance', 'Normalized RB', 'Adaptive ADMM'}; 
figure, 
semilogy(outs1.tols, '-.g'), 
hold, 
semilogy(outs3.tols, '-.r'); 
semilogy(outs4.tols, '--m'); 
semilogy(outs5.tols, '--', 'Color',[0.7 0.2 0.2]); 
semilogy(outs2.tols, 'b'); 
ylabel('Relative residual', 'FontName','Times New Roman'); 
ylim([10^(-3) 10]); 
xlabel('Iteration', 'FontName','Times New Roman'); 
legend(legends, 'FontName','Times New Roman'); 

运行结果图:

3. 参考文献

文章中一些内容引自网络,会注明出处或引用为参考文献,难免有未尽之处,如有不妥,请随时联系删除。

4. Matlab代码、数据、文章

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号