问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

从零开始人工智能Matlab案例-手写数字识别

创作时间:
作者:
@小白创作中心

从零开始人工智能Matlab案例-手写数字识别

引用
CSDN
1.
https://blog.csdn.net/weixin_31268759/article/details/145439196

手写数字识别是机器学习领域最经典的入门案例之一,MNIST数据集作为手写数字识别的标准数据集,包含了60000个训练样本和10000个测试样本。本教程将使用MATLAB的Deep Learning Toolbox,从零开始实现一个简单的卷积神经网络(CNN),完成对手写数字的识别任务。

案例目标

使用MATLAB的Deep Learning Toolbox训练一个神经网络,识别手写数字(0-9)。

步骤 1:准备数据

  1. 加载 MNIST 数据集

MATLAB 内置了 MNIST 数据集,可以直接加载。

% 加载训练数据和测试数据
[XTrain, YTrain] = digitTrain4DArrayData;
[XTest, YTest] = digitTest4DArrayData;

% 查看数据维度
disp(size(XTrain)); % 28x28x1x60000(28x28像素,单通道,6万张训练图)

% 随机显示25张训练图片
figure;
perm = randperm(numel(YTrain), 25);
for i = 1:25
    subplot(5,5,i);
    imshow(XTrain(:,:,:,perm(i)));
    title(char(YTrain(perm(i)))); % 显示标签
end

步骤 2:构建神经网络

  1. 定义网络结构

构建一个简单的卷积神经网络(CNN)。

layers = [
    imageInputLayer([28 28 1])  % 输入层(28x28x1的灰度图)
    convolution2dLayer(3, 8, 'Padding', 'same') % 卷积层(3x3滤波器,8个通道)
    batchNormalizationLayer    % 批归一化
    reluLayer                  % ReLU激活函数
    maxPooling2dLayer(2, 'Stride', 2) % 最大池化层(2x2窗口,步长2)
    convolution2dLayer(3, 16, 'Padding', 'same') % 第二层卷积(16个通道)
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    fullyConnectedLayer(10)    % 全连接层(10个输出对应0-9)
    softmaxLayer               % Softmax分类
    classificationLayer];      % 分类输出层

analyzeNetwork(layers); % 可视化网络结构

步骤 3:训练模型

  1. 设置训练参数
options = trainingOptions('sgdm', ...        % 使用随机梯度下降
    'InitialLearnRate', 0.01, ...            % 初始学习率
    'MaxEpochs', 5, ...                      % 训练5轮
    'Shuffle', 'every-epoch', ...            % 每轮打乱数据
    'ValidationData', {XTest, YTest}, ...    % 验证集
    'Verbose', true, ...                     % 显示训练过程
    'Plots', 'training-progress');           % 绘制训练曲线
  1. 开始训练
net = trainNetwork(XTrain, YTrain, layers, options);

步骤 4:测试模型

  1. 预测测试集
YPred = classify(net, XTest); % 对测试集分类
  1. 计算准确率

accuracy = sum(YPred == YTest) / numel(YTest);
disp(['测试集准确率: ', num2str(accuracy * 100), '%']);
% 预期结果:约95%以上(受训练轮数和网络复杂度影响)
  1. 查看混淆矩阵
figure;
confusionchart(YTest, YPred);

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号