PyTorch中如何使用DataLoader对数据集进行批处理
PyTorch中如何使用DataLoader对数据集进行批处理
最近在研究MNIST手写数据集的神经网络搭建时,发现了一个有趣的问题:一个数据集里面包含很多个数据,不能一次全部喂入网络,所以需要将数据分成一小块一小块地喂入。PyTorch中有一个很方便的DataLoader函数,可以帮我们轻松实现这一功能。下面我将通过一个简单的例子,来说明如何使用DataLoader进行数据批处理。
第一步:打开冰箱门
在使用DataLoader之前,我们需要创建一个PyTorch能够识别的数据集类型。虽然PyTorch中有很多现成的数据集类型,但在这里我们先从最基础的开始。
首先,我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:
随后,我们需要将X和Y组成一个完整的数据集,并将其转化为PyTorch能识别的数据集类型:
我们来看一下这些数据的数据类型:
可以看出,我们通过TensorDataset()
函数将X和Y拼装成了一个数据集,数据集的类型是TensorDataset
。
第二步:把大象装进去
接下来,我们将上一步创建的数据集放入DataLoader中,这样就可以生成一个迭代器,从而方便地进行批处理。
DataLoader中有很多参数,这里列举一些常用的:
dataset
:需要加载的数据集,类型为Datasetbatch_size
:每个batch加载多少样本shuffle
:是否在每个epoch都对数据进行洗牌sampler
:从数据集中采样样本的方法num_workers
:加载数据时使用的子进程数量collate_fn
:自定义的样本合并函数pin_memory
:是否将数据复制到CUDA固定内存中drop_last
:如果最后一个batch的样本数量小于batch_size,是否丢弃这个batch
第三步:把冰箱门关上
现在,我们可以愉快地使用上面定义好的迭代器进行训练了。在这里,我们用print来模拟训练过程,即对搭建好的网络进行数据喂入。
输出的结果是:
可以看到,我们一共训练了所有的数据5次。数据中一共有10组,我们设置的mini-batch是3,即每一次我们训练网络的时候喂入3组数据,到了最后一次我们只有1组数据了,比mini-batch小,我们就仅输出这一个。
此外,还可以利用Python中的enumerate()
函数,对所有可以迭代的数据类型(如包含很多元素的list)进行取操作,用法如下:
好啦,现在冰箱门就关上啦,(^__^)
本文原文来自:https://www.cnblogs.com/JeasonIsCoding/p/10168753.html