文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

PyTorch Dataset与DataLoader使用超详细讲解

2024-04-02 19:55

关注

一、Dataset

Dataset 类提供一种方式去获取数据及其标签

主要有两个目的:

1. 在控制台进行操作

Hymenoptera (膜翅目昆虫)数据集下载地址:

链接: https://pan.baidu.com/s/1XKwXsAtE2yzZW2IsvBDpnw?pwd=8a5t

提取码: 8a5t 

这是一个蚂蚁蜜蜂二分类的数据集,通常数据集有以下三种组织形式(上面的数据集属于第一种):

①获取图片的基本信息

在Pycharm 中,点击下方的PythonConsole进入控制台进行操作(通过控制台可以看到变量的详细信息)

首先加载图片,逐行输入下方代码:

from PIL import Image
img_path = "./dataset/hymenoptera_data/train/ants/0013035.jpg"
img = Image.open(img_path)

此时我们就可以在右侧看到相关变量的信息:

点击img变量,可以查看图片的详细信息。通过控制台执行程序能够直观地获取后续操作所需的数据:

最后可以通过img.show()打开图片查看:

②获取文件的基本信息

同样还是在控制台逐行输入以下代码:

dir_path = "dataset/hymenoptera_data/train/ants"
import os
img_path_list = os.listdir(dir_path)
img_path_list[0]

我们就可以获取到文件夹下的文件名称,由于是使用控制台,我们还可以在右侧查看列表的详细信息:

因此在控制台操作是有很大的优点的,我们可以在控制台逐行执行已经编写好的文件中的语句,通过查看右侧变量的值来判断程序写的是否有问题

2. 编写一个继承Dataset 的类加载数据

下面的代码也可以在控制台运行(可以多行复制粘贴)来检验程序是否有误

①定义 MyData类

导入所需头文件:

from torch.utils.data import Dataset
from PIL import Image
import os

定义MyData类:

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    def __len__(self):
        return len(self.img_path)

其中os.path.join()可以实现多个路径的合并且不出错

②创建类的实例并调用

创建 MyData 类的实例:

if __name__ == "__main__":
    root_dir = "../dataset/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)

调用类中写好的函数:

    img, label = ants_dataset.__getitem__(3)
    print(ants_dataset.__len__(), label)
    img.show()

同时我们也可以通过下面这种方式用已有的数据集来创造数据集:

train_dataset = ants_dataset + bees_dataset

二、DataLoader

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

加载数据:

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

测试:

img, target = test_data[0]
print(img.shape)
print(target)

进行日志记录,开始训练:

writer = SummaryWriter("dataloader")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        print(imgs.shape)
        print(targets)
        writer.add_images("Epoch: {}".format(epoch), imgs, step)
        step = step + 1
writer.close()

到此这篇关于PyTorch Dataset与DataLoader使用超详细讲解的文章就介绍到这了,更多相关PyTorch Dataset与DataLoader内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     801人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     348人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     311人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     432人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     220人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯