文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

torchvision.transforms 数据预处理:ToTensor()

2023-10-08 06:07

关注

文章目录

ToTensor() 是pytorch中的数据预处理函数,包含在 torchvision.transforms 模块下。一般用于处理图像数据,所以其处理对象是 PIL Image 和 numpy.ndarray 。

1、ToTensor() 函数的作用

必须要声明不能只看函数名,就以为 ToTensor() 只是将图像转为 tensor,其实它的功能不止于此

看一下 ToTensor() 函数的源码:

class ToTensor:    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript.    Converts a PIL Image or numpy.ndarray (H x W x C) in the range    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)    or if the numpy.ndarray has dtype = np.uint8    In the other cases, tensors are returned without scaling.    .. note::        Because the input image is scaled to [0.0, 1.0], this transformation should not be used when        transforming target image masks. See the `references`_ for implementing the transforms for image masks.    .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation    """

大意是:

(1)将 PIL Image 或 numpy.ndarray 转为 tensor

(2)如果 PIL Image 属于 (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 中的一种图像类型,或者 numpy.ndarray 格式数据类型是 np.uint8 ,则将 [0, 255] 的数据转为 [0.0, 1.0] ,也就是说将所有数据除以 255 进行归一化。

(3)将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。

2、读取图像时 PIL 和 opencv 的选择

在自己建立 dataset 迭代器时,一般操作是检索数据集图像的路径,然后使用 PIL 库或 opencv库读取图片路径。

2.1 使用 PIL

import numpy as npfrom PIL import ImagefilePath="Dataset/FFHQ/00000.png"img1=Image.open(filePath)print(f"img1 = {img1}")    # img1 = img2 = np.array(img1)print(f"img2 = {img2}")"""img2 = [[[  0 130 146]  [  0 128 144]  [  0 125 141]  ...  [133 162 164]  [133 157 159]  [134 157 163]]]"""

可以看到,使用 PIL.Image 读取的图像是一种 PIL 类,mode=RGB,要想获得图像的像素值还需要将其转为 np.array 格式。

而 opencv 可以直接将图像读取为 np.array 格式,因此首选 opencv 。

2.2 使用 opencv

import cv2filePath="Dataset/FFHQ/00000.png"img=cv2.imread(filePath)print(f"img.shape = {img.shape}")     # img.shape = (128, 128, 3)print(f"img = {img}")     # img.dtype = uint8"""img = [[[146 130   0]  [144 128   0]  [141 125   0]  ...  [164 162 133]  [159 157 133]  [163 157 134]]]"""

仔细对比PIL 和 opencv 的输出结果可以发现,PIL默认输出的图片格式为 RGB,而opencv输出的是BGR格式。

使用opencv读取的图像是[H,W,C]大小的,数据格式是 np.uint8 ,经过 ToTensor() 会进行归一化。而其他的数据类型(如 np.int8)经过 ToTensor() 数值不变,不进行归一化,后面会详细讲述。并且经过ToTensor()后图像格式变为 [C,H,W]。

3、ToTensor() 的使用

3.1 关键知识点

不管是使用 PLT还是opencv,最终得到都是 np.array类型。因此:

ToTensor() 是将 np.array 的数据 转为 tensor 格式

这里一定要明确几个点:

(1)np.array 整型的默认数据类型为 np.int32,经过 ToTensor() 后数值不变,不进行归一化。(2)np.array 浮点型的默认数据类型为 np.float64,经过 ToTensor() 后数值不变,不进行归一化。(3)opencv 读取的图像格式为 np.array,其数据类型为 np.uint8    经过 ToTensor() 后数值由 [0,255] 变为 [0,1],通过将每个数据除以255进行归一化。(4)经过 ToTensor() 后,HWC 的图像格式变为 CHW 的 tensor 格式。(5)np.uint8 和 np.int8 不一样,uint8是无符号整型,数值都是正数。(6)ToTensor() 可以处理任意 shape 的 np.array,并不只是三通道的图像数据。

3.2 代码示例

下面通过代码熟悉 ToTensor() 的使用,分别看一下 np.uint8 和 非 np.uint8 类型的 np.array 经过 ToTensor() 之后的输出。

(1) np.uint8 类型

import numpy as npfrom torchvision import transformsdata = np.array([    [0, 5, 10, 20, 0],    [255, 125, 180, 255, 196]], dtype=np.uint8)tensor = transforms.ToTensor()(data)print(tensor)"""tensor([[[0.0000, 0.0196, 0.0392, 0.0784, 0.0000],         [1.0000, 0.4902, 0.7059, 1.0000, 0.7686]]])"""

(2)非 np.uint8 类型

import numpy as npfrom torchvision import transformsdata = np.array([    [0, 5, 10, 20, 0],    [255, 125, 180, 255, 196]])      # data.dtype = int32tensor = transforms.ToTensor()(data)print(tensor)"""tensor([[[  0,   5,  10,  20,   0],         [255, 125, 180, 255, 196]]], dtype=torch.int32)"""

来源地址:https://blog.csdn.net/qq_43799400/article/details/127785104

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯