问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

PyTorch中如何使用DataLoader对数据集进行批处理

创作时间:
作者:
@小白创作中心

PyTorch中如何使用DataLoader对数据集进行批处理

引用
1
来源
1.
https://www.xin3721.com/Articlenet/29571.html

最近在研究MNIST手写数据集的神经网络搭建时,发现了一个有趣的问题:一个数据集里面包含很多个数据,不能一次全部喂入网络,所以需要将数据分成一小块一小块地喂入。PyTorch中有一个很方便的DataLoader函数,可以帮我们轻松实现这一功能。下面我将通过一个简单的例子,来说明如何使用DataLoader进行数据批处理。

第一步:打开冰箱门

在使用DataLoader之前,我们需要创建一个PyTorch能够识别的数据集类型。虽然PyTorch中有很多现成的数据集类型,但在这里我们先从最基础的开始。

首先,我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

随后,我们需要将X和Y组成一个完整的数据集,并将其转化为PyTorch能识别的数据集类型:

我们来看一下这些数据的数据类型:

可以看出,我们通过TensorDataset()函数将X和Y拼装成了一个数据集,数据集的类型是TensorDataset

第二步:把大象装进去

接下来,我们将上一步创建的数据集放入DataLoader中,这样就可以生成一个迭代器,从而方便地进行批处理。

DataLoader中有很多参数,这里列举一些常用的:

  • dataset:需要加载的数据集,类型为Dataset
  • batch_size:每个batch加载多少样本
  • shuffle:是否在每个epoch都对数据进行洗牌
  • sampler:从数据集中采样样本的方法
  • num_workers:加载数据时使用的子进程数量
  • collate_fn:自定义的样本合并函数
  • pin_memory:是否将数据复制到CUDA固定内存中
  • drop_last:如果最后一个batch的样本数量小于batch_size,是否丢弃这个batch

第三步:把冰箱门关上

现在,我们可以愉快地使用上面定义好的迭代器进行训练了。在这里,我们用print来模拟训练过程,即对搭建好的网络进行数据喂入。

输出的结果是:

可以看到,我们一共训练了所有的数据5次。数据中一共有10组,我们设置的mini-batch是3,即每一次我们训练网络的时候喂入3组数据,到了最后一次我们只有1组数据了,比mini-batch小,我们就仅输出这一个。

此外,还可以利用Python中的enumerate()函数,对所有可以迭代的数据类型(如包含很多元素的list)进行取操作,用法如下:

好啦,现在冰箱门就关上啦,(^__^)

本文原文来自:https://www.cnblogs.com/JeasonIsCoding/p/10168753.html

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号