ResNet结构详解:从核心概念到代码实现
ResNet结构详解:从核心概念到代码实现
ResNet(残差网络)是深度学习领域中最具影响力的模型之一,由微软研究院的Kaiming He等人于2015年提出。它通过引入残差学习的概念,成功解决了深层网络训练中的梯度消失问题,并在多个视觉识别任务中取得了当时最佳的性能。本文将从核心概念、基本结构、优势以及代码示例四个方面,深入浅出地介绍ResNet的原理和应用。
一、ResNet的核心概念
1.残差学习
ResNet的核心思想是残差学习,即网络学习的是输入和输出之间的残差(即差异),而不是直接学习映射关系。这种设计使得网络更容易优化,特别是在深层网络中。
2.残差块(Residual Blocks)
ResNet由多个残差块组成,每个残差块包含两条路径:一条是卷积层的堆叠,另一条是恒等连接(Identity Connection)。这种结构允许梯度直接流向前面的层,从而缓解了深层网络训练中的梯度消失问题。
3.恒等连接
恒等连接允许输入直接跳过一些层的输出,然后与这些层的输出相加。这有助于解决深层网络训练中的梯度消失问题。
二、ResNet的基本结构
卷积层
输入数据首先通过一个卷积层,通常是一个具有相同填充的3x3卷积核,以减少数据的空间维度。残差块
每个残差块包含多个卷积层(通常是两个或更多),每个卷积层后面跟着一个批次归一化层和ReLU激活函数。恒等快捷连接
如果残差块的输入和输出的维度不一致,使用1x1卷积核进行维度匹配,然后与块内的卷积层输出相加。层叠残差块
多个残差块按顺序堆叠,形成更深的网络结构。全局平均池化
在残差块堆叠之后,使用全局平均池化层将特征图转换为一维特征向量。全连接层
最后,一个或多个全连接层用于分类任务,输出最终的类别概率。
下图是不同层数的残差网络的架构:
三、ResNet的优势
解决梯度消失问题
通过恒等连接,ResNet允许梯度直接流向前面的层,从而缓解了深层网络训练中的梯度消失问题。简化网络设计
残差块的引入简化了网络的设计,使得可以轻松地构建更深的网络。提高模型性能
ResNet在多个视觉识别任务上取得了优异的性能,证明了其有效性。
四、代码示例
使用迁移学习加载resnet18网络,冻结大部分参数,微调需要修改的部分。
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
'''将resnet18模型迁移到食物分类项目中'''
# 即调用了resnet18网络,又使用了训练好的模型
resnet_model = models.resnet18(weights = models.ResNet18_Weights.DEFAULT)
# weights=models.ResNet18_Weights.DEFAULT表示使用在ImageNet数据集上预先训练好的权重来初始化模型参数,可进入源代码查看
for param in resnet_model.parameters():
param.requires_grad = False
# 模型所有参数(即权重和偏差)的requires_grad属性设置为False, 从而冻结所有模型参数
# 使得在反向传播过程中不会计算它们的梯度,以此减少模型的计算量,提高推理速度。
# 修改全连接层适应当前项目
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)
# 创建一个全连接层 输入特征为in_features,输出为20
params_to_update = [] # 保存需要训练的参数,仅仅包含全连接层的参数
for param in resnet_model.parameters():
if param.requires_grad == True: #默认重新赋值后requires_grad就不是false
params_to_update.append(param)