问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

PyTorch中torchvision数据集和dataloader使用详解

创作时间:
作者:
@小白创作中心

PyTorch中torchvision数据集和dataloader使用详解

引用
CSDN
1.
https://blog.csdn.net/Q20011102/article/details/141395825

本文是一篇关于PyTorch中torchvision数据集和dataloader使用的详细教程。文章从数据集的下载、读取,到结合transform进行数据预处理,再到dataloader的各种参数设置,层层递进,内容详尽且实用,非常适合PyTorch和深度学习的初学者阅读。

数据集简介

PyTorch官方提供了丰富的数据集资源,其中CIFAR10是一个常用的数据集,主要用于图像分类任务。CIFAR10数据集包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。数据集分为50000张训练图像和10000张测试图像。

程序中下载数据集

使用CIFAR10数据集时,可以通过torchvision.datasets模块进行下载和加载。以下是一个简单的示例代码:

import torchvision

# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)

# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

运行程序后,可以看到数据集开始下载。如果下载速度较慢,可以尝试直接点击蓝色的下载链接进行下载,或者将下载链接放在迅雷等下载工具中进行下载。CIFAR10数据集大小约为100多兆,相对较小,适合用于练手。

读取数据集

下载完成后,可以通过索引方式读取数据集中的图像和标签。以下是一个读取测试数据集的示例代码:

import torchvision

# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)

# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

# 查看数据集
print(test_set[0])  # (<PIL.Image.Image image mode=RGB size=32x32 at 0x27E68307FA0>, 3)
# 图片组成部分 第一个是图片 第二个是target
print(test_set.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

img, target = test_set[0]
print(img)  # <PIL.Image.Image image mode=RGB size=32x32 at 0x27E21B660E0>
print(target)  # 3   实际上对应的就是test_set.classes的第4个也就是cat,因为它是从0开始算的,说明这张图片对应的是猫
print(test_set.classes[target])  # cat

# 展示图片  PIL的Image可以直接show
img.show()

数据集中的图片大小为32x32像素,可以通过PIL库直接展示。

结合transform进行读取数据集

在实际应用中,通常需要对图像数据进行预处理,例如转换为Tensor类型。可以通过torchvision.transforms模块实现这一功能。以下是一个使用transform读取数据集的示例代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter

# 定义transform
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)

# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# 使用tensorboard展示图片
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()

运行程序后,可以通过tensorboard查看转换后的图像数据。

dataloader

在深度学习中,通常需要将数据集加载到神经网络中进行训练。dataloader的作用就是将数据集分批加载到神经网络中,可以设置每次加载的数据量、是否打乱数据顺序等参数。以下是一个使用dataloader的示例代码:

import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# 定义dataloader
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

# 使用tensorboard展示图片
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step += 1
writer.close()

dataloader参数详解

  • batch_size:每次从数据集中取数据的数量。
  • shuffle:是否打乱数据顺序。为True时,每次加载的数据顺序会随机打乱;为False时,保持原始顺序。
  • num_workers:用于数据加载的进程数。默认为0,表示使用主进程进行加载。在Windows系统中,num_workers大于0时可能会出现一些错误。
  • drop_last:对于不能整除batch_size的数据,是否丢弃剩余数据。为True时,会丢弃剩余数据;为False时,会保留剩余数据。

参数测试

batch_size的参数测试

test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)

drop_last的参数测试

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

shuffle的参数测试

  • shuffle为False时,两次加载的数据顺序相同:
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)

for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
  • shuffle为True时,两次加载的数据顺序会随机打乱:
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

全部代码

import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# 准备测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# 定义dataloader
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

# 使用tensorboard展示图片
writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
writer.close()

通过以上代码,可以直观地看到不同参数设置对数据加载的影响,帮助更好地理解dataloader的使用。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号