Using datasets in torchvision
1. Using datasets in torchvision
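Official documentation
Note the version selector in the top-left corner of the docs, and make sure it matches the torchvision version you have installed.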
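Note 1: don't forget the parentheses when instantiating ToTensor
Write `transforms.ToTensor()`, not `transforms.ToTensor`. Without the parentheses, `Compose` stores the class itself instead of an instance. It then calls the constructor on each image, so indexing the dataset later (for example `test_set2[0]`) raises a TypeError.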
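Note 2: download can always stay True
`download` can always be left as `True`. After the first download, the dataset files are already in the specified directory, and the code does not download them again (it prints "Files already downloaded and verified"). You can also put a dataset archive that you downloaded yourself into that directory, and the code will extract it automatically.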
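If you download the archive manually (for example with Xunlei), it is worth checking that the file is complete before putting it in `root`. As far as I can tell from torchvision's source, `CIFAR10` looks for the archive directly under `root` (here `./dataset/cifar-10-python.tar.gz`) and checks its MD5 before extracting it. The sketch below uses only the standard library. The MD5 is the value published on the CIFAR-10 download page. `md5_of` and `archive_is_ready` are hypothetical helpers, not torchvision APIs.

```python
import hashlib
import os

# Values for the CIFAR-10 python archive; the MD5 is the one published on the
# CIFAR-10 download page (torchvision stores the same value as CIFAR10.tgz_md5)
CIFAR10_FILENAME = "cifar-10-python.tar.gz"
CIFAR10_TGZ_MD5 = "c58f30108f718f92721af3b95e74349a"


def md5_of(path, chunk_size=1024 * 1024):
    """Compute the MD5 hex digest of a file, reading it in chunks."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def archive_is_ready(root, expected_md5=CIFAR10_TGZ_MD5, filename=CIFAR10_FILENAME):
    """Hypothetical helper: True if root/filename exists and its MD5 matches."""
    path = os.path.join(root, filename)
    return os.path.isfile(path) and md5_of(path) == expected_md5


# Usage: check the archive you placed in ./dataset before running the main script
print(archive_is_ready("./dataset"))
```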
Code
```python
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# Usage 1
# If the download is very slow, you can use the Xunlei (Thunder) download manager;
# its properties panel shows it pulls from several sources, so it is much faster:
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
train_set = datasets.CIFAR10(root='./dataset', train=True, download=True)
test_set = datasets.CIFAR10(root='./dataset', train=False, download=True)

# Without a transform, each sample is a PIL image; you can inspect it in the debugger
print(test_set[0])  # __getitem__ returns (img, target)
print(type(test_set[0]))
img, target = test_set[0]
print(target)
print(test_set.classes[target])
print(img)
# A PIL image can be displayed directly with show()
img.show()

# Usage 2
# Apply transforms to the whole dataset so that samples come back as tensors
# trans_compose = transforms.Compose([transforms.ToTensor])  # wrong: missing (), causes an error later
trans_compose = transforms.Compose([transforms.ToTensor()])
train_set2 = datasets.CIFAR10(root='./dataset', train=True, transform=trans_compose, download=True)
test_set2 = datasets.CIFAR10(root='./dataset', train=False, transform=trans_compose, download=True)

print(type(test_set2[2]))
img, target = test_set2[0]
print(target)
print(test_set2.classes[target])
print(type(img))

# Write the first 10 test images to TensorBoard under the tag 'tensor dataset'
writer = SummaryWriter("logs")
for i in range(10):
    img_tensor, target = test_set2[i]
    writer.add_image('tensor dataset', img_tensor, i)
writer.close()
```
Output
```
> p11_torchvision_dataset.py
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>, 3)
<class 'tuple'>
3
cat
<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>
Files already downloaded and verified
Files already downloaded and verified
<class 'tuple'>
3
cat
<class 'torch.Tensor'>

Process finished with exit code 0
```
2. Using DataLoader
Source: https://blog.csdn.net/weixin_42831564/article/details/132561251