梯度下降过程可视化工具及案例分析
创作时间:
作者:
@小白创作中心
梯度下降过程可视化工具及案例分析
引用
CSDN
1.
https://blog.csdn.net/weixin_43589323/article/details/137237632
梯度下降是机器学习和深度学习中常用的一种优化算法,用于寻找函数的最小值。为了更直观地理解梯度下降的过程,本文将介绍一个自定义的梯度下降可视化工具,并通过多个案例展示不同参数设置下的收敛效果。
梯度下降可视化效果展示
函数 f(x) = x²
我们首先观察最简单的二次函数 f(x) = x² 在不同学习率下的梯度下降过程。
学习率 η = 0.1(较慢的收敛)
学习率 η = 0.5(比较好的收敛)
学习率 η = 0.9(振荡的收敛)
从上述结果可以看出,学习率的选择对梯度下降的收敛速度和稳定性有重要影响。过小的学习率会导致收敛速度过慢,而过大的学习率则可能导致振荡。
函数 f(x) = 0.15πcos(0.15πx)
接下来,我们观察一个更复杂的函数 f(x) = 0.15πcos(0.15πx) 在不同参数设置下的梯度下降过程。
学习率 η = 0.1,起始位置 x₀ = 15
学习率 η = 0.2,起始位置 x₀ = -8
学习率 η = 3,起始位置 x₀ = 15
这些案例进一步说明了学习率和初始点对梯度下降过程的影响。在复杂函数中,选择合适的参数组合对于快速收敛到全局最小值至关重要。
总结
梯度下降过程受到起始值和学习率(迭代步长)的影响,为了能够让收敛过程更加快速准确,需要认真对待初始化过程。
可视化工具代码实现
下面是实现梯度下降过程可视化的完整代码:
import torch
from torch import Tensor
from typing import Callable
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inline
def use_svg_display():
"""Use the svg format to display a plot in Jupyter."""
backend_inline.set_matplotlib_formats('svg')
def set_figsize(figsize=(4.5, 3.5)):
"""Set the figure size for matplotlib."""
use_svg_display()
plt.rcParams['figure.figsize'] = figsize
class GradientDescentVisualization:
def __init__(self, func: Callable,
steps: int=10,
eta: float=0.1,
init_point: Tensor=None):
"""
Initialize the gradient descent visualization tool.
Args:
func (Callable): Objective function
steps (int, optional): Total steps to perform gradient descent. Defaults to 10.
eta (float, optional): Learning rate. Defaults to 0.1.
init_point (Tensor, optional): Start point of the progress. Defaults to None.
"""
self.func = func
self.steps = steps
self.eta = eta
self.init_point = init_point
if self.init_point is None: # If not given, initialize with normal distribution
self.init_point = torch.randn(1, requires_grad=True)
assert self.init_point.requires_grad is True
def evolution(self):
"""Perform gradient descent progress."""
x = self.init_point
eta = self.eta
func = self.func
steps = self.steps
# Record evolution of x
x_evolution = [x.data.item()]
for _ in range(steps):
y = func(x) # Compute output
y.backward() # Backward
x.data -= eta * x.grad.data # Compute next point
x.grad.data.zero_() # Clear grad
x_evolution.append(x.data.item()) # Record evolution
self.x_evolution = x_evolution
def show_trace(self, bound: float=None):
"""Plot gradient descent progress."""
x_evolution = self.x_evolution
f = self.func
# Get bound of x_evolution
if bound is None:
bound = max(abs(min(x_evolution)), abs(max(x_evolution)))
f_line = torch.arange(-bound, bound, 0.01)
set_figsize()
# Plot graph of objective function
plt.plot(f_line, [f(x) for x in f_line], '-', c='b')
# Plot the gradient descent progress
points = list(zip(x_evolution, [f(x) for x in x_evolution]))
# Use annotate to plot the arrow
for i in range(len(points) - 1):
plt.gca().annotate("", xy=points[i+1], xytext=points[i],
arrowprops=dict(arrowstyle="->", lw=1.0, fc="red", ec="red"))
# Plot the start position of the gradient descent progress
ax1 = plt.plot(points[0][0], points[0][1], 'ro')
# Plot the final position of the gradient descent progress
ax2 = plt.plot(points[-1][0], points[-1][1], 'go')
plt.legend([ax1[0], ax2[0]], ['start', 'end'])
plt.title(f'$\eta$ = {self.eta:.3f}')
plt.xlim(-bound, bound)
plt.xlabel('x')
plt.ylabel('func')
plt.grid()
plt.show()
使用示例
以下是如何使用上述工具进行梯度下降过程可视化的示例:
import torch
c = torch.tensor(0.15 * torch.pi)
def func(x): # Objective function
return x * torch.cos(c * x)
steps = 10 # Number of iterations
eta = 0.2 # Learning rate
init_point = torch.tensor([-8.0], requires_grad=True) # Initial position
# Initialize the visualization tool
gdv = GradientDescentVisualization(func, steps, eta, init_point)
# Perform gradient descent
gdv.evolution()
# Visualize the process
gdv.show_trace()
通过调整 func
、steps
、eta
和 init_point
的参数,可以观察不同函数和参数设置下的梯度下降过程。
热门推荐
常温超导破局!薛其坤团队引爆千亿产业重构风暴
波长对光的衍射现象实验研究
冰岛虞美人种子种植指南
园林可以用虞美人吗 花期是多久 花语是什么
我国的黄酒生产特点
宝宝爱咬手指、啃笔怎么办?这些方法很实用
德国工作时间揭秘:效率与平衡的完美结合
十三烷的化学性质及用途
人工智能时代的人才新标准:创造性与人机协同能力的重要性
如何进行软件产品的数据隐私合规性检查
国产CPU电脑性能实测:海光3250、麒麟990与龙芯3A5000深度对比!
尿液四大变化预警痛风风险,专家详解关键检查指标
天麻西洋参可以一起吃吗
辛弃疾《满江红》:豪放激昂中的家国情怀
年龄相差较大仍可幸福婚姻:如何克服婚姻中的代沟?
新人入团队如何破冰
如何高效制作美观实用的表格:工具、步骤与设计技巧解析
项目管理表格如何制成
游戏建模师工资揭秘:薪资水平、影响因素与行业前景
工业数据建模的薪资会随着工作经验增加吗?
如何提高网盘下载速度,让文件下载更快捷?
公交车当校车合法吗
Dev-C++6.3安装详细教程及中文乱码解决办法
猫咪饮食搭配大全——为你的爱宠提供全面营养
让家猫开心的 13 种方法
为珊瑚搭建“避难所”:广西北海涠洲岛的珊瑚礁保护与修复之路
珊瑚礁的保护措施如何实施以维护生态平衡?这些措施有哪些具体效果?
如何高效管理和存储家庭照片:数字化时代的最佳实践
巧用废旧iPad!轻松改造成家庭电子相册
红毛丹的三种吃法