以MNIST为例
from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])
结果
(<PIL.Image.Image image mode=L size=28x28 at 0x196E3F1D898>, 5)
MINIST数据集的dataset是由一张图片和一个label组成的元组
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
print(each)
break
结果
[(<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105630>, 0), (<PIL.Image.Image image mode=L size=28x28 at 0x2CB3B105668>, 2)]
collate_fn为lamda x:x时表示对传入进来的数据不做处理
下面自定义collate_fn看看什么效果
def collate(data):
img = []
label = []
for each in data:
img.append(each[0])
label.append(each[1])
return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
print(each)
break
结果
([<PIL.Image.Image image mode=L size=28x28 at 0x241433A36D8>, <PIL.Image.Image image mode=L size=28x28 at 0x241433A3710>], [9, 3])
说明:若不设置collate_fn参数则会使用默认处理函数
但必须保证传进来的数据都是tensor格式否则会报错
附:DataLoader完整的参数表如下:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
DataLoader在数据集上提供单进程或多进程的迭代器
几个关键的参数意思:
- shuffle:设置为True的时候,每个世代都会打乱数据集
- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
- drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留
总结
到此这篇关于Pytorch使用技巧之Dataloader中的collate_fn参数的文章就介绍到这了,更多相关Dataloader中的collate_fn参数内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!