问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

使用PyTorch训练一个手写数字识别模型(MNIST)

创作时间:
作者:
@小白创作中心

使用PyTorch训练一个手写数字识别模型(MNIST)

引用
CSDN
1.
https://blog.csdn.net/m0_67724631/article/details/138806650

本文将介绍如何使用PyTorch训练一个手写数字识别模型(MNIST)。MNIST数据集是一个经典的机器学习基准数据集,包含大约60,000个训练样本和10,000个测试样本,每个样本都是一个28x28像素的手写数字图像,标签为0到9。

准备工作

首先导入必要的库,并定义一个简单的神经网络结构。这个神经网络由三个线性层组成,每个线性层之间使用ReLU激活函数进行激活。最后一层使用log softmax作为输出。类似下图

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

初始化神经网络、损失函数和优化器。

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

定义数据预处理的转换。将图像转换为PyTorch张量,并对图像进行标准化处理。

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号