从0开始深度学习:LeNet卷积神经网络详解与实践
创作时间:
作者:
@小白创作中心
从0开始深度学习:LeNet卷积神经网络详解与实践
引用
CSDN
1.
https://m.blog.csdn.net/m0_53115174/article/details/143608566
LeNet是最早的卷积神经网络之一,由Yann LeCun等人在1990年代提出,并以其名字命名。最初,LeNet被设计用于手写数字识别,最著名的应用是在美国的邮政系统中识别手写邮政编码。LeNet架构的成功证明了卷积神经网络在解决实际问题中的有效性,为后续更复杂、更强大的CNN模型的发展奠定了基础。
LeNet网络结构
LeNet的网络结构如下:
- 两个卷积层,分别使用5x5的卷积核
- 两个平均池化层,步长为2
- 三个全连接层
使用PyTorch实现该结构的代码如下:
import torch
from torch import nn
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, padding=2),
nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5),
nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
nn.Linear(120, 84), nn.Sigmoid(),
nn.Linear(84, 10)
)
模型训练
为了检测LeNet-5在Fashion-MNIST数据集上的表现,我们使用以下代码进行训练:
import torchvision
from torch.utils import data
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) # 标准化到[-1, 1]区间,加快计算
])
# 加载Fashion-MNIST数据集
train_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=True, download=False, transform=transform)
test_dataset = datasets.FashionMNIST(root='D:/DL_Data/', train=False, download=False, transform=transform)
train_loader = data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)
# 训练参数
lr, num_epochs = 0.9, 10
# 训练函数
def train(net, train_iter, test_iter, num_epochs, lr):
def init_weights(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss()
train_acc_list, test_acc_list = [], []
for epoch in range(num_epochs):
net.train()
for X, y in train_iter:
optimizer.zero_grad()
X, y = X.to(device), y.to(device)
y_hat = net(X)
l = loss(y_hat, y)
l.backward()
optimizer.step()
train_acc = evaluate_acc(net, train_iter, device)
test_acc = evaluate_acc(net, test_iter, device)
train_acc_list.append(train_acc)
test_acc_list.append(test_acc)
print(f"epoch: {epoch+1}, train_acc: {train_acc:.3f}, test_acc: {test_acc:.3f}")
return train_acc_list, test_acc_list
# 评估函数
def evaluate_acc(net, data_iter, device):
net.eval()
metric = [0, 0]
with torch.no_grad():
for X, y in data_iter:
X, y = X.to(device), y.to(device)
y_hat = net(X)
metric[0] += (y_hat.argmax(dim=1) == y).sum().item()
metric[1] += y.numel()
return metric[0] / metric[1]
train_acc_list, test_acc_list = train(net, train_loader, test_loader, num_epochs, lr)
结果分析
训练完成后,我们可以绘制训练和测试准确率的折线图:
epochs = range(1, num_epochs + 1)
plt.plot(epochs, train_acc_list, 'b', label='Training Accuracy')
plt.plot(epochs, test_acc_list, 'r', label='Testing Accuracy')
plt.title('Training and Testing Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
从图像可以看出,准确率还没有稳定,说明还有提升空间,可以添加epoch继续训练以获得更准的分类效果。
热门推荐
朱震亨与他的医学派别
中医流派之温补学派
劳动争议仲裁调解法的主要内容是什么?
浴缸材质分类
俄语学习计划小白如何制定
新型水凝胶敷料为糖尿病足溃疡治疗带来新希望
红茶、绿茶、乌龙茶都适合搭配什么茶点?20种茶点特点和适配茶类
骨质增生的原因及预防管理
民事诉讼一审时间是多久 有什么流程
如何挑选合适的灯具?怎样根据不同需求选择灯具?
关于妇女节,你了解多少?
荆州古城的三个文化符号
笔记本选购指南:睿频与主频的全面解析及选择建议
赣州十大名小吃,赣州舌尖上的传奇
医保缴费年限“新规定”!多地已正式执行,参保职工要受益了
三国杀玩法全攻略:从入门到精通的十一种游戏模式详解
很累很烦很压抑怎么办
刑事辩护律师如何制定审判策略
投资理财知识全面解析 —— 掌握投资理财的基础与技巧
维生素D3注射液的作用及功效
如何提高社会治理数据的透明度?
测量不确定度基础知识
如何解读他人的表情
绿茶和咖啡,哪个对身体更好?
职业病的类型有哪些?如何判断是否属于职业病?
大众汽车发展历史
揭秘洗涤剂背后的科学!洗衣粉、洗衣液、洗衣凝珠,选对了吗?
北京工资收入分布状况探究:一个城市的工资差异揭秘
Excel绘制弧形座位图的详细教程
藏不住了!广西人爱吃酸嘢,原来还有这些好处