文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

使用PyTorch怎么读取数据

2023-06-14 08:21

关注

本篇文章给大家分享的是有关使用PyTorch怎么读取数据,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。

模块介绍

import zipfile # 解压import pandas as pd # 操作数据import os # 操作文件或文件夹import cv2 # 图像操作库import matplotlib.pyplot as plt # 图像展示库from torch.utils.data import Dataset # PyTorch内置对象from torchvision import transforms # 图像增广转换库 PyTorch内置import torch

初步读取数据

数据下载到此处
我们先初步编写一个脚本来实现图片的展示

# 解压文件到指定目录def unzip_file(root_path, filename):  full_path = os.path.join(root_path, filename)  file = zipfile.ZipFile(full_path)  file.extractall(root_path)unzip_file(root_path, zip_filename)# 读入csv文件face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))# pandas读出的数据如想要操作索引 使用ilocimage_name = face_landmarks.iloc[:,0]landmarks = face_landmarks.iloc[:,1:]# 展示def show_face(extract_path, image_file, face_landmark):  plt.imshow(plt.imread(os.path.join(extract_path, image_file)), cmap='gray')  point_x = face_landmark.to_numpy()[0::2]  point_y = face_landmark.to_numpy()[1::2]  plt.scatter(point_x, point_y, c='r', s=6)  show_face(extract_path, image_name.iloc[1], landmarks.iloc[1])

使用PyTorch怎么读取数据

使用内置库来实现

实现MyDataset

使用内置库是我们的代码更加的规范,并且可读性也大大增加
继承Dataset,需要我们实现的有两个地方:

class FaceDataset(Dataset):  def __init__(self, extract_path, csv_filename, transform=None):    super(FaceDataset, self).__init__()    self.extract_path = extract_path    self.csv_filename = csv_filename    self.transform = transform    self.face_landmarks = pd.read_csv(os.path.join(extract_path, csv_filename))  def __len__(self):    return len(self.face_landmarks)  def __getitem__(self, idx):    image_name = self.face_landmarks.iloc[idx,0]    landmarks = self.face_landmarks.iloc[idx,1:].astype('float32')    point_x = landmarks.to_numpy()[0::2]    point_y = landmarks.to_numpy()[1::2]    image = plt.imread(os.path.join(self.extract_path, image_name))    sample = {'image':image, 'point_x':point_x, 'point_y':point_y}    if self.transform is not None:      sample = self.transform(sample)    return sample

测试功能是否正常

face_dataset = FaceDataset(extract_path, csv_filename)sample = face_dataset[0]plt.imshow(sample['image'], cmap='gray')plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)plt.title('face')

使用PyTorch怎么读取数据

实现自己的数据处理模块

内置的在torchvision.transforms模块下,由于我们的数据结构不能满足内置模块的要求,我们就必须自己实现
图片的缩放,由于缩放后人脸的标注位置也应该发生对应的变化,所以要自己实现对应的变化

class Rescale(object):  def __init__(self, out_size):    assert isinstance(out_size,tuple) or isinstance(out_size,int), 'out size isinstance int or tuple'    self.out_size = out_size  def __call__(self, sample):    image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']    new_h, new_w = self.out_size if isinstance(self.out_size,tuple) else (self.out_size, self.out_size)    new_image = cv2.resize(image,(new_w, new_h))    h, w = image.shape[0:2]    new_y = new_h / h * point_y    new_x = new_w / w * point_x    return {'image':new_image, 'point_x':new_x, 'point_y':new_y}

将数据转换为torch认识的数据格式因此,就必须转换为tensor
注意: cv2matplotlib读出的图片默认的shape为N H W C,而torch默认接受的是N C H W因此使用tanspose转换维度,torch转换多维度使用permute

class ToTensor(object):  def __call__(self, sample):    image, point_x, point_y = sample['image'], sample['point_x'], sample['point_y']    new_image = image.transpose((2,0,1))    return {'image':torch.from_numpy(new_image), 'point_x':torch.from_numpy(point_x), 'point_y':torch.from_numpy(point_y)}

测试

transform = transforms.Compose([Rescale((1024, 512)), ToTensor()])face_dataset = FaceDataset(extract_path, csv_filename, transform=transform)sample = face_dataset[0]plt.imshow(sample['image'].permute((1,2,0)), cmap='gray')plt.scatter(sample['point_x'], sample['point_y'], c='r', s=2)plt.title('face')

使用PyTorch怎么读取数据

使用Torch内置的loader加速读取数据

data_loader = DataLoader(face_dataset, batch_size=4, shuffle=True, num_workers=0)for i in data_loader:  print(i['image'].shape)  break
torch.Size([4, 3, 1024, 512])

注意: windows环境尽量不使用num_workers会发生报错

以上就是使用PyTorch怎么读取数据,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注编程网行业资讯频道。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯