问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

PyTorch数据处理:Dataset类详解与实战

创作时间:
作者:
@小白创作中心

PyTorch数据处理:Dataset类详解与实战

引用
1
来源
1.
https://tingsongyu.github.io/PyTorch-Tutorial-2nd/chapter-3/3.1-dataset.html

PyTorch数据处理:Dataset类详解与实战

在PyTorch中,虽然DataLoader是数据模块的核心,但用户最常修改和直接接触的部分是Dataset。本节将深入分析Dataset的作用,并通过三个案例展示如何编写自定义的Dataset来读取不同类型的数据集。

Dataset的功能

PyTorch提供的torch.utils.data.Dataset类是一个抽象基类,供用户继承以编写自己的数据集。在实现自定义Dataset时,必须重写两个函数:__getitem____len__

  • __getitem__:需要实现读取单个样本的功能。通常接收一个索引(可以是序号或键),从磁盘中读取数据并进行预处理(包括在线数据增强),最后返回一个样本的数据。数据可以包括模型需要的输入、标签以及其他元信息(如图片路径)。
  • __len__:返回数据集的大小。如果这个函数返回0,DataLoader会抛出错误:"ValueError: num_samples should be a positive integer value, but got num_samples=0"。这个错误通常是因为文件路径配置错误,导致数据集找不到任何数据。

理解了Dataset类的概念后,我们通过一幅示意图来理解DatasetDataLoader的关系:

Dataset负责与磁盘交互,读取并预处理数据,而DataLoader则负责将这些样本组装成批数据,并实现各种采样策略。采样的规则可以通过sampler参数由用户自定义,可以方便地实现均衡采样、随机采样、有偏采样、渐进式采样等。

在实际应用中,我们通常需要通过一个_get_img_info函数来建立与磁盘的关系。这个函数负责收集并处理样本的路径信息、标签信息,并存储到一个列表中,供__getitem__函数使用。__getitem__函数只需要拿到序号,就可以获得图片的路径信息、标签信息,接着进行图片预处理,最后返回一个样本信息。

三个Dataset案例

为了帮助大家掌握不同类型数据的读取,这里构建了三个不同形式的数据集案例:

  1. 数据的划分及标签在txt文件中
  2. 数据的划分及标签在文件夹中体现
  3. 数据的划分及标签在csv文件中

每个案例都详细展示了如何通过_get_img_info函数收集数据信息,并在__getitem__函数中读取和预处理数据。

第一个案例:数据标签在txt文件中

在这个案例中,数据的划分和标签信息存储在一个txt文件中。通过解析这个txt文件,我们可以获取每个样本的路径和标签信息。

第二个案例:数据标签在文件夹中体现

在这个案例中,数据的划分和标签信息通过文件夹结构来体现。每个类别对应一个文件夹,文件夹中的图片属于该类别。

第三个案例:数据标签在csv文件中

在这个案例中,数据的划分和标签信息存储在一个csv文件中。通过解析这个csv文件,我们可以获取每个样本的路径和标签信息。

代码输出分析

代码输出主要分为两部分:

  1. __getitem__输出:输出PIL对象及图像标签。可以看到在__getitem__函数中采用了img = Image.open(path_img).convert('L')对图片进行了读取,得到了PIL对象。由于没有设置transform,不对图像进行任何预处理,因此返回的图像是PIL对象。

  2. 结合DataLoader的使用:这种形式更贴近真实场景。在这里为Dataset设置了一些transform,包括图像的缩放、转换为Tensor以及归一化。因此,__getitem__返回的图像变为了张量的形式,并且在DataLoader中组装成了batchsize的形式。大家可以尝试修改缩放的大小来观察输出,也可以注释normalize来观察它们的作用。

0 torch.Size([2, 1, 4, 4]) tensor([[[[-0.0431, -0.1216, -0.0980, -0.1373],
          [-0.0667, -0.2000, -0.0824, -0.2392],
          [-0.1137,  0.0353,  0.1843, -0.2078],
          [ 0.0510,  0.3255,  0.3490, -0.0510]]],

        [[[-0.3569, -0.2863, -0.3333, -0.4118],
          [ 0.0196, -0.3098, -0.2941,  0.1059],
          [-0.2392, -0.1294,  0.0510, -0.2314],
          [-0.1059,  0.4118,  0.4667,  0.0275]]]]) torch.Size([2]) tensor([1, 0])

关于transform的系列方法以及工作原理,将在本章后半部分讲解数据增强部分再详细展开。

小结

本节介绍了torch.utils.data.Dataset类的结构及工作原理,并通过三个案例实践,加深大家对自行编写Dataset的认识。关于Dataset的编写,torchvision也有很多常用公开数据集的Dataset模板,建议大家学习。下一小节将介绍DataLoader类的使用。

补充学习建议

  • IDE的debug:下一小节的代码将采用debug模式进行逐步分析,建议大家提前熟悉pycharm等IDE的debug功能。
  • python的迭代器:相信很多初学者对代码中的“next(iter(train_set))”不太了解,这里建议大家了解iter概念、next概念、迭代器概念、以及双下划线函数概念。
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号