文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

torch.utils.data.DataLoader与迭代器转换的方法

2023-06-29 05:51

关注

这篇文章主要介绍“torch.utils.data.DataLoader与迭代器转换的方法”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“torch.utils.data.DataLoader与迭代器转换的方法”文章能帮助大家解决问题。

在做实验时,我们常常会使用用开源的数据集进行测试。而Pytorch中内置了许多数据集,这些数据集我们常常使用DataLoader类进行加载。
如下面这个我们使用DataLoader类加载torch.vision中的FashionMNIST数据集。

from torch.utils.data import DataLoaderfrom torchvision import datasetsfrom torchvision.transforms import ToTensorimport matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(    root="data",    train=True,    download=True,    transform=ToTensor())test_data = datasets.FashionMNIST(    root="data",    train=False,    download=True,    transform=ToTensor())

我们接下来定义Dataloader对象用于加载这两个数据集:

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

那么这个train_dataloader究竟是什么类型呢?

print(type(train_dataloader))  # <class 'torch.utils.data.dataloader.DataLoader'>

我们可以将先其转换为迭代器类型。

print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>

然后再使用next(iter(train_dataloader))从迭代器里取数据,如下所示:

train_features, train_labels = next(iter(train_dataloader))print(f"Feature batch shape: {train_features.size()}")print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()label = train_labels[0]plt.imshow(img, cmap="gray")plt.show()print(f"Label: {label}")

可以看到我们成功获取了数据集中第一张图片的信息,控制台打印:

Feature batch shape: torch.Size([64, 1, 28, 28])Labels batch shape: torch.Size([64])Label: 2

图片可视化显示如下:

torch.utils.data.DataLoader与迭代器转换的方法

不过有读者可能就会产生疑问,很多时候我们并没有将DataLoader类型强制转换成迭代器类型呀,大多数时候我们会写如下代码:

for train_features, train_labels in train_dataloader:     print(train_features.shape) # torch.Size([64, 1, 28, 28])    print(train_features[0].shape) # torch.Size([1, 28, 28])    print(train_features[0].squeeze().shape) # torch.Size([28, 28])        img = train_features[0].squeeze()    label = train_labels[0]    plt.imshow(img, cmap="gray")    plt.show()    print(f"Label: {label}")

可以看到,该代码也能够正常迭代训练数据,前三个样本的控制台打印输出为:

torch.Size([64, 1, 28, 28])torch.Size([1, 28, 28])torch.Size([28, 28])Label: 7torch.Size([64, 1, 28, 28])torch.Size([1, 28, 28])torch.Size([28, 28])Label: 4torch.Size([64, 1, 28, 28])torch.Size([1, 28, 28])torch.Size([28, 28])Label: 1

那么为什么我们这里没有显式将Dataloader转换为迭代器类型呢,其实是Python语言for循环的一种机制,一旦我们用for ... in ...句式来迭代一个对象,那么Python解释器就会偷偷地自动帮我们创建好迭代器,也就是说

for train_features, train_labels in train_dataloader:

实际上等同于

for train_features, train_labels in iter(train_dataloader):

更进一步,这实际上等同于

train_iterator = iter(train_dataloader)try:    while True:        train_features, train_labels = next(train_iterator)except StopIteration:    pass

推而广之,我们在用Python迭代直接迭代列表时:

for x in [1, 2, 3, 4]:

其实Python解释器已经为我们隐式转换为迭代器了:

list_iterator = iter([1, 2, 3, 4])try:    while True:        x = next(list_iterator)except StopIteration:    pass

关于“torch.utils.data.DataLoader与迭代器转换的方法”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注编程网行业资讯频道,小编每天都会为大家更新不同的知识点。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯