PyTorch数据处理:Dataset类详解与实战
PyTorch数据处理:Dataset类详解与实战
PyTorch数据处理:Dataset类详解与实战
在PyTorch中,虽然DataLoader
是数据模块的核心,但用户最常修改和直接接触的部分是Dataset
。本节将深入分析Dataset
的作用,并通过三个案例展示如何编写自定义的Dataset
来读取不同类型的数据集。
Dataset的功能
PyTorch提供的torch.utils.data.Dataset
类是一个抽象基类,供用户继承以编写自己的数据集。在实现自定义Dataset
时,必须重写两个函数:__getitem__
和__len__
。
__getitem__
:需要实现读取单个样本的功能。通常接收一个索引(可以是序号或键),从磁盘中读取数据并进行预处理(包括在线数据增强),最后返回一个样本的数据。数据可以包括模型需要的输入、标签以及其他元信息(如图片路径)。__len__
:返回数据集的大小。如果这个函数返回0,DataLoader
会抛出错误:"ValueError: num_samples should be a positive integer value, but got num_samples=0"。这个错误通常是因为文件路径配置错误,导致数据集找不到任何数据。
理解了Dataset
类的概念后,我们通过一幅示意图来理解Dataset
与DataLoader
的关系:
Dataset
负责与磁盘交互,读取并预处理数据,而DataLoader
则负责将这些样本组装成批数据,并实现各种采样策略。采样的规则可以通过sampler
参数由用户自定义,可以方便地实现均衡采样、随机采样、有偏采样、渐进式采样等。
在实际应用中,我们通常需要通过一个_get_img_info
函数来建立与磁盘的关系。这个函数负责收集并处理样本的路径信息、标签信息,并存储到一个列表中,供__getitem__
函数使用。__getitem__
函数只需要拿到序号,就可以获得图片的路径信息、标签信息,接着进行图片预处理,最后返回一个样本信息。
三个Dataset案例
为了帮助大家掌握不同类型数据的读取,这里构建了三个不同形式的数据集案例:
- 数据的划分及标签在txt文件中
- 数据的划分及标签在文件夹中体现
- 数据的划分及标签在csv文件中
每个案例都详细展示了如何通过_get_img_info
函数收集数据信息,并在__getitem__
函数中读取和预处理数据。
第一个案例:数据标签在txt文件中
在这个案例中,数据的划分和标签信息存储在一个txt文件中。通过解析这个txt文件,我们可以获取每个样本的路径和标签信息。
第二个案例:数据标签在文件夹中体现
在这个案例中,数据的划分和标签信息通过文件夹结构来体现。每个类别对应一个文件夹,文件夹中的图片属于该类别。
第三个案例:数据标签在csv文件中
在这个案例中,数据的划分和标签信息存储在一个csv文件中。通过解析这个csv文件,我们可以获取每个样本的路径和标签信息。
代码输出分析
代码输出主要分为两部分:
__getitem__
输出:输出PIL对象及图像标签。可以看到在__getitem__
函数中采用了img = Image.open(path_img).convert('L')
对图片进行了读取,得到了PIL对象。由于没有设置transform,不对图像进行任何预处理,因此返回的图像是PIL对象。结合DataLoader的使用:这种形式更贴近真实场景。在这里为
Dataset
设置了一些transform,包括图像的缩放、转换为Tensor以及归一化。因此,__getitem__
返回的图像变为了张量的形式,并且在DataLoader
中组装成了batchsize的形式。大家可以尝试修改缩放的大小来观察输出,也可以注释normalize来观察它们的作用。
0 torch.Size([2, 1, 4, 4]) tensor([[[[-0.0431, -0.1216, -0.0980, -0.1373],
[-0.0667, -0.2000, -0.0824, -0.2392],
[-0.1137, 0.0353, 0.1843, -0.2078],
[ 0.0510, 0.3255, 0.3490, -0.0510]]],
[[[-0.3569, -0.2863, -0.3333, -0.4118],
[ 0.0196, -0.3098, -0.2941, 0.1059],
[-0.2392, -0.1294, 0.0510, -0.2314],
[-0.1059, 0.4118, 0.4667, 0.0275]]]]) torch.Size([2]) tensor([1, 0])
关于transform的系列方法以及工作原理,将在本章后半部分讲解数据增强部分再详细展开。
小结
本节介绍了torch.utils.data.Dataset
类的结构及工作原理,并通过三个案例实践,加深大家对自行编写Dataset
的认识。关于Dataset
的编写,torchvision
也有很多常用公开数据集的Dataset
模板,建议大家学习。下一小节将介绍DataLoader
类的使用。
补充学习建议
- IDE的debug:下一小节的代码将采用debug模式进行逐步分析,建议大家提前熟悉pycharm等IDE的debug功能。
- python的迭代器:相信很多初学者对代码中的“next(iter(train_set))”不太了解,这里建议大家了解iter概念、next概念、迭代器概念、以及双下划线函数概念。