PyTorch中torchvision数据集和dataloader使用详解
PyTorch中torchvision数据集和dataloader使用详解
本文是一篇关于PyTorch中torchvision数据集和dataloader使用的详细教程。文章从数据集的下载、读取,到结合transform进行数据预处理,再到dataloader的各种参数设置,层层递进,内容详尽且实用,非常适合PyTorch和深度学习的初学者阅读。
数据集简介
PyTorch官方提供了丰富的数据集资源,其中CIFAR10是一个常用的数据集,主要用于图像分类任务。CIFAR10数据集包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。数据集分为50000张训练图像和10000张测试图像。
程序中下载数据集
使用CIFAR10数据集时,可以通过torchvision.datasets模块进行下载和加载。以下是一个简单的示例代码:
import torchvision
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
运行程序后,可以看到数据集开始下载。如果下载速度较慢,可以尝试直接点击蓝色的下载链接进行下载,或者将下载链接放在迅雷等下载工具中进行下载。CIFAR10数据集大小约为100多兆,相对较小,适合用于练手。
读取数据集
下载完成后,可以通过索引方式读取数据集中的图像和标签。以下是一个读取测试数据集的示例代码:
import torchvision
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)
# 查看数据集
print(test_set[0]) # (<PIL.Image.Image image mode=RGB size=32x32 at 0x27E68307FA0>, 3)
# 图片组成部分 第一个是图片 第二个是target
print(test_set.classes) # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
img, target = test_set[0]
print(img) # <PIL.Image.Image image mode=RGB size=32x32 at 0x27E21B660E0>
print(target) # 3 实际上对应的就是test_set.classes的第4个也就是cat,因为它是从0开始算的,说明这张图片对应的是猫
print(test_set.classes[target]) # cat
# 展示图片 PIL的Image可以直接show
img.show()
数据集中的图片大小为32x32像素,可以通过PIL库直接展示。
结合transform进行读取数据集
在实际应用中,通常需要对图像数据进行预处理,例如转换为Tensor类型。可以通过torchvision.transforms模块实现这一功能。以下是一个使用transform读取数据集的示例代码:
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 定义transform
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 训练数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
# 测试数据集
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
# 使用tensorboard展示图片
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
运行程序后,可以通过tensorboard查看转换后的图像数据。
dataloader
在深度学习中,通常需要将数据集加载到神经网络中进行训练。dataloader的作用就是将数据集分批加载到神经网络中,可以设置每次加载的数据量、是否打乱数据顺序等参数。以下是一个使用dataloader的示例代码:
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# 定义dataloader
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# 使用tensorboard展示图片
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("test_data", imgs, step)
step += 1
writer.close()
dataloader参数详解
- batch_size:每次从数据集中取数据的数量。
- shuffle:是否打乱数据顺序。为True时,每次加载的数据顺序会随机打乱;为False时,保持原始顺序。
- num_workers:用于数据加载的进程数。默认为0,表示使用主进程进行加载。在Windows系统中,num_workers大于0时可能会出现一些错误。
- drop_last:对于不能整除batch_size的数据,是否丢弃剩余数据。为True时,会丢弃剩余数据;为False时,会保留剩余数据。
参数测试
batch_size的参数测试
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
drop_last的参数测试
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
shuffle的参数测试
- shuffle为False时,两次加载的数据顺序相同:
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1
- shuffle为True时,两次加载的数据顺序会随机打乱:
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
全部代码
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
# 定义dataloader
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 使用tensorboard展示图片
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
writer.add_images("Epoch:{}".format(epoch), imgs, step)
step += 1
writer.close()
通过以上代码,可以直观地看到不同参数设置对数据加载的影响,帮助更好地理解dataloader的使用。