文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Python深度学习albumentations数据增强库

2024-04-02 19:55

关注

数据增强的必要性

深度学习在最近十年得以风靡得益于计算机算力的提高以及数据资源获取的难度下降。一个好的深度模型往往需要大量具有label的数据,使得模型能够很好的学习这种数据的分布。而给数据打标签往往是一件耗时耗力的工作。
拿cv里的经典任务为例,classification需要人准确识别物品类别或者生物种类,object detection需要人工画出bounding box, 确定其坐标,semantic segmentation甚至需要在像素级别进行标签标注。对于一些专业领域的图像标注,依赖于专业人士的知识素养(例如医疗,遥感等),这无疑对有标签数据的收集带来了麻烦。

那么有没有什么方法能够在数据集规模很小的情况,尽可能提高模型的表现力呢?

1.transfer learning或者说是domain adaptation,这种方法期望降低源域与目标域之间的数据分布差异,使得具有大量标注数据的源域帮助提升模型的训练效果。

2.对现有数据进行数据增强深度学习能够学习到的空间不变性,像素级别的不变性特征都有限。所以对图片进行平移,缩放,旋转,改变色调值等方法,可以使得模型见过各种类型的数据,提高模型在测试数据上的判别力。

albumentations

上面我只是笼统的谈了下数据增强的必要性,对于其更加深刻的理解往往需要在实验中不断体会或者总结。

albumentations的安装

这步没什么好说,利用包管理工具直接安装。


pip install albumentations

albumentations的流水线工作方式

导入所需要的库


import albumentations as A
from PIL import Image
import numpy as np

读入数据这步需要其它库进行配合,可以利用CV2,PIL等,这里出于习惯我选择使用PIL


image_path = './your/image/path'
image = np.array(Image.open(image_path))  # 获得了一个[H, W, C]的三维数组

创建流水线


transform = A.Compose([
	A.Resize(width=256, height=256),
	A.HorizontalFlip(p=0.5),
	A.RandomBrightnessContrast(p=0.2)
])

A.Compose中需要传入一个list, list包含了一系列数据增强操作的对象。这里可以理解为A.Compose返回一条工业流水线, 第一步进行A.Resize操作,将图片缩放成256 * 256;第二步在上一步的基础上以0.5的概率对图片进行镜像翻转(p这个参数代表进行这个操作的概率);第三步同理,对第一步第二步处理完的图像以0.2的概率进行亮度和对比度的改变。

transform就是我们将要对图片进行的操作流程,下一步就需要将图片数据传入进去。

获得数据增强完的图片数据


transformed = transform(image=image)
tranformed_image = transformed['image']

将图片数据传递给transform(很明显这是个可调用的对象)的image参数,它会返回一个处理完的对象,对象的key值image对应的value就是处理完的图像数据。

图像处理结果展示

在这里插入图片描述

object detection的数据增强

上述对albumentations流水线工作过程的简要说明其实就是classification任务的大致流程。
当然,albumentations如果仅仅只能做到上述的功能,那么torchvision中transform API可以把它完全替代,并且它也满足不了大多数cv任务的数据增强需求。

拿object detection为例,一张图片数据往往对应了若干个bounding box,如果你对图片数据进行的操作具有空间变换性,那么原有的bounding box数据画出的目标框必然已经对应不了图片中的对象了。
所以对图片数据进行变换的同时也必须对bounding box数据进行变换,保持二者的一致性。

绘制目标框

在介绍object detection的数据增强之前,先介绍一个绘制目标框的函数。在albumentation中展示的代码是用cv2实现,个人觉得画出的bounding box不太美观,下面使用的是matplotlib实现的代码。


import matplotlib.pyplot as plt
import matplotlib.patches as patches
def visualize_bbox(img, bbox, class_name, color, ax):
	"""
	img:图片数据 (H, W, C)数据格式
	bbox:array或者tensor, 假定数据格式是 [x_mid, y_mid, width, height]
	classname:str 目标框对应的种类
	color:str
	thickness:目标框的宽度
	"""
	x_mid, y_mid, width, height = bbox
	x_min = int(x_mid - width / 2)
	y_min = int(y_mid - height / 2)
	# 画目标检测框
	rect = patches.Rectangle((x_min, y_min), 
								width, 
								height, 
								linewidth=3,
								edgecolor=color,
								facecolor="none"
								)
	ax.imshow(img)
	ax.add_patch(rect)
	ax.text(x_min + 1, y_min - 3, class_name, fontSize=10, bbox={'facecolor':color, 'pad': 3, 'edgecolor':color})
def visualize(img, bboxes, category_ids, category_id_to_name, category_id_to_color):
	fig, ax = plt.subplots(1, figsize=(8, 8))
	ax.axis('off')
	for box, category in zip(bboxes, category_ids):
		class_name = category_id_to_name[category]
		color = category_id_to_color[category]
		visualize_bbox(img, box, class_name, color, ax)
	plt.show()

在这里插入图片描述

对bounding box进行空间变换

导入所需要的库


import albumentations as A
from PIL import Image
import numpy as np
image_path = './your/image/path'
image = np.array(Image.open(image_path))

构造流水线


transform = A.Compose([
	A.Resize(width=256, height=256),
	A.HorizontalFlip(p=0.5),
	A.RandomBrightnessContrast(p=0.2)
], bbox_params = A.BboxParams(format='yolo'))

相较于最简单的流水线(for classification),oject detection需要传入一个叫做bbox_params的参数,它接收的是用于配置bounding box参数的对象。
format表示的是bounding box数据的格式,albumentations提供了4种格式。

在这里插入图片描述

1.pascal_voc [x_min, y_min, x_max, y_max] 数值并没有归一化

   直接使用像素值[98, 345, 420, 462]

2.albumentations [x_min, y_min, x_max, y_max] 与上一种格式不一样的是

    这里值都是normalized 做了归一化处理[0.153125, 0.71875, 0.65625, 0.9625]

3.coco [x_min, y_min, width, height] 没有归一化

4.yolo [x_center, y_center, width, height] 归一化了

传入image数据和bounding box数据进行变换


label = np.array([
        [0.339, 0.6693333333333333, 0.402, 0.42133333333333334],
        [0.379, 0.5666666666666667, 0.158, 0.3813333333333333],
        [0.612, 0.7093333333333333, 0.084, 0.3466666666666667],
        [0.555, 0.7026666666666667, 0.078, 0.34933333333333333]
])  # normalized (x_center, y_center, width, height) 对应format yolo
category_ids = [12, 14, 14, 14]
category_id_to_name = {
    12: 'horse',
    14: 'people'
}
category_id_to_color = {
    12: 'yellow',
    14: 'red'
}
transformed = transform(image=image,bboxes=label)
transformed_image = transformed['image']
transformed_bboxes = transformed['bboxes']
height, width, _ = transformed_image.shape
transformed_bboxes[:, [0, 2]] = transformed_bboxes[:, [0, 2]] * width
transformed_bboxes[:, [1, 3]] = transformed_bboxes[:, [1, 3]] * height
visualize(transformed_image, transformed_bboxes, category_ids, category_id_to_name, category_id_to_color)

在这里插入图片描述

BboxParams中不止format这一个参数。当我们做随机裁剪操作的时候,bounding box完全可能只保留了一部分,当保留比例小于某一个阈值的时候,我们可以将其drop掉,具体的操作细节可以查看albumentations的相关教程。

semantic segmentation的数据增强

object detection和semantic segmentation在像素级别的data agumentation和classification没什么区别,而在空间变换上segmentation没有bounding box变换,与之对应的是mask变换。
mask是像素级别的label,与原图中的像素一一对应。
albumentations上的教程使用的是kaggle上的数据集,这里为了方便展示我们使用同样的数据集。

数据集网址

在这里插入图片描述

下载完数据并解压缩完成后可以得到如上的目录结构,通过train.csv文件可以得到所用的image和mask名称。


image = np.array(Image.open(image_path))  # 这里使用的是/train/images/0fea4b5049.png
mask = np.array(Image.open(mask_path))  # /train/masks/0fea4b5049.png

下面介绍一下展示结果的函数


from matplotlib import pyplot as plt
def visualize(image, mask, original_image=None, original_mask=None):
	fontsize=8
	if original_image == None and original_mask == None:
		fg, ax = plt.subplots(2, 1, figsize=(8, 8))
		ax[0].axis('off')
		ax[0].imshow(image)
		ax[0].set_title('image', fontsize=fontsize)
		ax[1].axis('off')
		ax[1].imshow(mask)
		ax[1].set_title('mask', fontsize=fontsize)
	else:
		fg, ax = plt.subplots(2, 2, figsize=(8, 8))
		ax[0, 0].axis('off')
		ax[0, 0].imshow(original_image)
		ax[0, 0].set_title('Original Image', fontsize=fontsize)
		ax[0, 1].axis('off')
		ax[0, 1].imshow(original_mask)
		ax[0, 1].set_title('Original Mask', fontsize=fontsize)
		ax[1, 0].axis('off')
		ax[1, 0].imshow(image)
		ax[1, 0].set_title('Transformed Image', fontsize=fontsize)
		ax[1, 1].axis('off')
		ax[1, 1].imshow(mask)
		ax[1, 1].set_title('Transformed Mask', fontsize=fontsize)	

data agumentation的流水线操作


aug = A.PadIfNeeded(min_height=128, min_width=128, p=1)
augmented = aug(image=image, mask=mask)
augmented_img = augmented['image']
augmented_mask = augmented['mask']
visualize(augmented_img, augmented_mask, original_image=image, original_mask=mask)

这里相较于classification就是多了个mask函数,将mask数据直接传进入即可。

在这里插入图片描述

padding的填充方式默认是reflection, 可以看到变换以后的mask右侧多了些黄色区域。
对于一些分割任务而言,我们不想增加或者删除额外的信息,所以往往采用 Non destructive transformations(非破坏性变换)如HorizontalFlip(水平翻转), VerticalFlip(垂直翻转), RandomRotate90(Randomly rotates by 0, 90, 180, 270 degrees)


aug = A.RandomRotate(p=1)
augmented = aug(image=image, mask=mask)
augmented_image = augmented['image']
augmented_mask = augmented['mask']
visualize(augmented_image, augmented_mask, original_image=image, original_mask=mask)

在这里插入图片描述

下面介绍下多个transform综合起来的流水线操作


original_height, original_width = image.shape[:2]
aug = A.Compose([
    A.OneOf([
        A.RandomSizedCrop(min_max_height=(50, 101), height=original_height, width=original_width, p=0.5),
        A.PadIfNeeded(min_height=original_height, min_width=original_width, p=0.5)
    ]),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.OneOf([
        A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        A.GridDistortion(p=0.5),
        A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=1)
    ], p=0.8)
])
augmented = aug(image=image, mask=mask)
image_medium = augmented['image']
mask_medium = augmented['mask']
visualize(image_medium, mask_medium, original_image=image, original_mask=mask)

这里一个较新的知识点是A.OneOf,它接收的transform对象的list, 从中按照权重随机选择一个进行变换,它本身也有概率。

在这里插入图片描述

可以看到OneOf将list中的transform的概率进行归一化再重新分配。所以这里transform的p不再理解为概率,而是权重,取到1,甚至比1大都没有关系。

以上就是Python深度学习albumentations数据增强库的详细内容,更多关于Python数据增强库albumentations的资料请关注编程网其它相关文章!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯