文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Pytorch数据读取与预处理的实现方法

2023-06-14 08:33

关注

这篇文章给大家分享的是有关Pytorch数据读取与预处理的实现方法的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。

  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。

  根据数据是否一次性读取完,将DEMO分为:

  1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存足够的情况下使用,速度更快。

  2、并行式读取。也就是边训练边读取数据。通常用在内存不够的情况下使用,会占用计算资源,如果分配的好的话,几乎不损失速度。

  Pytorch官方的数据提取方式尽管方便编码,但由于它提取数据方式比较死板,会浪费资源,下面对其进行分析。

1  串行式读取

1.1  DEMO代码

import torch from torch.utils.data import Dataset,DataLoader   class MyDataSet(Dataset):# ————1———— def __init__(self):    self.data = torch.tensor(range(10)).reshape([5,2])  self.label = torch.tensor(range(5)) def __getitem__(self, index):    return self.data[index], self.label[index] def __len__(self):    return len(self.data) my_data_set = MyDataSet()# ————2————my_data_loader = DataLoader( dataset=my_data_set,  # ————3———— batch_size=2,     # ————4———— shuffle=True,     # ————5———— sampler=None,     # ————6———— batch_sampler=None,  # ————7————  num_workers=0 ,    # ————8————  collate_fn=None,    # ————9————  pin_memory=True,    # ————10————  drop_last=True     # ————11————)for i in my_data_loader: # ————12———— print(i)

  注释处解释如下:

  1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。

  2~5、 2准备好数据集后,传入DataLoader来迭代生成数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。

  6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于生成数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。

  7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据生成顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以使用它们来定义。

  8、用于生成数据的子进程数。默认为0,不并行。

  9、拼接多个样本的方法,默认是将每个batch的数据在第一维上进行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,后面再详细说明。

  10、是否使用锁页内存。用的话会更快,内存不充足最好别用。

  11、是否把最后小于batch的数据丢掉。

  12、迭代获取数据并输出。

1.2  速度探索

  首先看一下DEMO的输出:

Pytorch数据读取与预处理的实现方法

  输出了两个batch的数据,每组数据中data和label都正确排列,符合我们的预期。那么DataLoader是怎么把数据整合起来的呢?首先,我们把collate_fn定义为直接映射(不用它默认的方法),来查看看每次DataLoader从MyDataSet中读取了什么,将上面部分代码修改如下:

my_data_loader = DataLoader( dataset=my_data_set,   batch_size=2,       shuffle=True,       sampler=None,      batch_sampler=None,   num_workers=0 ,     collate_fn=lambda x:x, #修改处 pin_memory=True,     drop_last=True     )

  结果如下:

Pytorch数据读取与预处理的实现方法

  输出还是两个batch,然而每个batch中,单个的data和label是在一个list中的。似乎可以看出,DataLoader是一个一个读取MyDataSet中的数据的,然后再进行相应数据的拼接。为了验证这点,代码修改如下:

import torch from torch.utils.data import Dataset,DataLoader   class MyDataSet(Dataset):  def __init__(self):    self.data = torch.tensor(range(10)).reshape([5,2])  self.label = torch.tensor(range(5)) def __getitem__(self, index):    print(index)     #修改处2  return self.data[index], self.label[index] def __len__(self):    return len(self.data) my_data_set = MyDataSet() my_data_loader = DataLoader( dataset=my_data_set,   batch_size=2,       shuffle=True,       sampler=None,      batch_sampler=None,   num_workers=0 ,     collate_fn=lambda x:x, #修改处1 pin_memory=True,     drop_last=True     )for i in my_data_loader:  print(i)

  输出如下:

Pytorch数据读取与预处理的实现方法

  验证了前面的猜想,的确是一个一个读取的。如果数据集定义的不是格式化的数据,那还好,但是我这里定义的是tensor,是可以直接通过列表来索引对应的tensor的。因此,DataLoader的操作比直接索引多了拼接这一步,肯定是会慢很多的。一两次的读取还好,但在训练中,大量的读取累加起来,就会浪费很多时间了。

  自定义一个DataLoader可以证明这一点,代码如下:

import torch from torch.utils.data import Dataset,DataLoader from time import time  class MyDataSet(Dataset):  def __init__(self):    self.data = torch.tensor(range(100000)).reshape([50000,2])  self.label = torch.tensor(range(50000)) def __getitem__(self, index):    return self.data[index], self.label[index] def __len__(self):    return len(self.data)# 自定义DataLoaderclass MyDataLoader(): def __init__(self, dataset,batch_size):  self.dataset = dataset  self.batch_size = batch_size def __iter__(self):  self.now = 0  self.shuffle_i = np.array(range(self.dataset.__len__()))   np.random.shuffle(self.shuffle_i)  return self  def __next__(self):   self.now += self.batch_size  if self.now <= len(self.shuffle_i):   indexes = self.shuffle_i[self.now-self.batch_size:self.now]   return self.dataset.__getitem__(indexes)  else:   raise StopIteration# 使用官方DataLoadermy_data_set = MyDataSet() my_data_loader = DataLoader( dataset=my_data_set,   batch_size=256,       shuffle=True,       sampler=None,      batch_sampler=None,   num_workers=0 ,     collate_fn=None,  pin_memory=True,     drop_last=True     )start_t = time()for t in range(10): for i in my_data_loader:   passprint("官方:", time() - start_t)  #自定义DataLoadermy_data_set = MyDataSet() my_data_loader = MyDataLoader(my_data_set,256)start_t = time()for t in range(10): for i in my_data_loader:   passprint("自定义:", time() - start_t)

运行结果如下:

Pytorch数据读取与预处理的实现方法

  以上使用batch大小为256,仅各读取10 epoch的数据,都有30多倍的时间上的差距,更大的batch差距会更明显。另外,这里用于测试的每个数据只有两个浮点数,如果是图像,所需的时间可能会增加几百倍。因此,如果数据量和batch都比较大,并且数据是格式化的,最好自己写数据生成器。

2  并行式读取

2.1  DEMO代码

import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader from torchvision import transforms from torchvision.datasets import ImageFolder  path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'my_data_set = ImageFolder(      #————1———— root = path,            #————2———— transform = transforms.Compose([  #————3————  transforms.ToTensor(),  transforms.CenterCrop(64) ]), loader = plt.imread         #————4————)my_data_loader = DataLoader( dataset=my_data_set,    batch_size=128,        shuffle=True,        sampler=None,        batch_sampler=None,     num_workers=0,       collate_fn=None,       pin_memory=True,       drop_last=True )      for i in my_data_loader:  print(i)

  注释处解释如下:

  1/2、ImageFolder类继承自DataSet类,因此可以按索引读取图像。路径必须包含文件夹,ImageFolder会给每个文件夹中的图像添加索引,并且每张图像会给予其所在文件夹的标签。举个例子,代码中my_data_set[0] 输出的是图像对象和它对应的标签组成的列表。

  3、图像到格式化数据的转换组合。更多的转换方法可以看 transform 模块。

  4、图像法的读取方式,默认是PIL.Image.open(),但我发现plt.imread()更快一些。

  由于是边训练边读取,transform会占用很多时间,因此可以先将图像转换为需要的形式存入外存再读取,从而避免重复操作。

  其中transform.ToTensor()会把正常读取的图像转换为torch.tensor,并且像素值会映射至[0,1][0,1]。由于plt.imread()读取png图像时,像素值在[0,1][0,1],而读取jpg图像时,像素值却在[0,255][0,255],因此使用transform.ToTensor()能将图像像素区间统一化。

感谢各位的阅读!关于“Pytorch数据读取与预处理的实现方法”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,让大家可以学到更多知识,如果觉得文章不错,可以把它分享出去让更多的人看到吧!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯