问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

绘制近似线性可分支持向量机的分类边界和支持向量

创作时间:
作者:
@小白创作中心

绘制近似线性可分支持向量机的分类边界和支持向量

引用
CSDN
1.
https://blog.csdn.net/u013172930/article/details/143271435

支持向量机(SVM)是一种常用的机器学习算法,广泛应用于分类和回归问题。在二分类问题中,SVM通过寻找最大间隔超平面来实现分类。本文将介绍一个用于绘制SVM分类边界和支持向量的Python函数,帮助读者直观理解SVM的工作原理。

下面这段代码定义了一个函数plot_classifier,用于可视化支持向量机(SVM)的分类结果、支持向量和决策边界。具体来说,它绘制了两个训练集的点,标记了支持向量,并显示了分类器的决策边界和间隔边界。让我们逐步解释每一部分的功能:

1. 绘制训练数据点

plt.plot(X1_train[:,0], X1_train[:,1], "ro")
plt.plot(X2_train[:,0], X2_train[:,1], "go")
  • X1_trainX2_train:分别是属于两个不同类别的训练数据集。假设X1_train代表第一类的数据,X2_train代表第二类的数据。每个数据点有两个特征,所以它们是二维的。
  • 使用plt.plot绘制训练数据:
  • "ro"用红色圆圈绘制第一类训练数据点。
  • "go"用绿色圆圈绘制第二类训练数据点。

2. 绘制支持向量

plt.scatter(clf.spv[:,0], clf.spv[:,1],
            s=100, c="y", edgecolors="b", label="support vector")
  • clf.spv:这是训练好的 SVM 模型clf中的支持向量(spv),包含所有在训练过程中被识别为支持向量的样本点。
  • 使用plt.scatter绘制支持向量:
  • s=100:设置支持向量的大小。
  • c="y":支持向量的颜色设为黄色。
  • edgecolors="b":支持向量的边框颜色为蓝色。
  • label="support vector":用于图例标记支持向量。

3. 创建网格数据

X1, X2 = np.meshgrid(np.linspace(-4, 4, 50), np.linspace(-4, 4, 50))
X = np.array([[x1, x2] for x1, x2 in zip(np.ravel(X1), np.ravel(X2))])
  • np.meshgrid:生成一个二维的网格数据,这些网格点用于绘制分类边界。np.linspace(-4, 4, 50)表示生成从 -4 到 4 的 50 个等间隔点,X1X2分别对应网格的横轴和纵轴坐标。
  • np.ravel(X1)np.ravel(X2):将网格点展平成一维数组,便于后续将每个网格点的坐标组合。
  • X:将网格点(x1, x2)组合成二维数组,作为分类器的输入,计算这些点的分类结果。

4. 计算网格上的分类结果

Z = clf.project(X).reshape(X1.shape)
  • clf.project(X):通过分类器clf对网格上的每个点X进行分类,返回的结果是分类器的决策函数值f(x) = w^T x + b,用于确定分类边界。
  • Z:分类结果是一个形状与网格X1X2相同的二维数组,用于绘制等高线图。

5. 绘制决策边界和间隔边界

plt.contour(X1, X2, Z, [0.0], colors='k', linewidths=1, origin='lower')
plt.contour(X1, X2, Z + 1, [0.0], colors='grey', linewidths=1, origin='lower')
plt.contour(X1, X2, Z - 1, [0.0], colors='grey', linewidths=1, origin='lower')
  • plt.contour:用于绘制等高线图,显示分类器的决策边界和间隔边界。
  • Z:分类结果,其中Z = 0表示决策边界(超平面),对应分类函数f(x) = 0
  • Z + 1Z - 1:分别表示间隔边界f(x) = 1f(x) = -1
  • colors='k':决策边界的颜色为黑色。
  • colors='grey':间隔边界的颜色为灰色。
  • linewidths=1:设置线条宽度。

6. 显示图例和绘图

plt.legend()
plt.show()
  • plt.legend():显示图例,标注支持向量。
  • plt.show():展示完整的绘图结果。

总结

  • 输入数据点:函数通过plt.plot绘制两个类别的训练数据点,红色代表第一类,绿色代表第二类。
  • 支持向量:使用plt.scatter绘制支持向量,并用黄色标记、蓝色边框强调支持向量的重要性。
  • 分类边界:通过plt.contour绘制决策边界(黑色)和间隔边界(灰色)。
  • 网格点预测:通过在二维网格上的预测,确定分类器的决策区域,并在图中可视化。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号