文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

使用TensorFlow训练图像分类模型的指南

2024-12-13 15:51

关注

众所周知,人类在很小的时候就学会了识别和标记自己所看到的事物。如今,随着机器学习和深度学习算法的不断迭代,计算机已经能够以非常高的精度,对捕获到的图像进行大规模的分类了。目前,此类先进算法的应用场景已经涵括到了包括:解读肺部扫描影像是否健康,通过移动设备进行面部识别,以及为零售商区分不同的消费对象类型等领域。

下面,我将和您共同探讨计算机视觉(Computer Vision)的一种应用——图像分类,并逐步展示如何使用TensorFlow,在小型图像数据集上进行模型的训练。

1、数据集和目标

在本示例中,我们将使用MNIST数据集的从0到9的数字图像。其形态如下图所示:

我们训练该模型的目的是为了将图像分类到其各自的标签下,即:它们在上图中各自对应的数字处。通常,深度神经网络架构会提供一个输入、一个输出、两个隐藏层(Hidden Layers)和一个用于训练模型的Dropout层。而CNN或卷积神经网络(Convolutional Neural Network)是识别较大图像的首选,它能够在减少输入量的同时,捕获到相关的信息。

2、准备工作

首先,让我们通过TensorFlow、to_categorical(用于将数字类的值转换为其他类别)、Sequential、Flatten、Dense、以及用于构建神经网络架构的 Dropout,来导入所有相关的代码库。您可能会对此处提及的部分代码库略感陌生。我会在下文中对它们进行详细的解释。

3、超参数

import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout
params = {
'dropout': 0.25,
'batch-size': 128,
'epochs': 50,
'layer-1-size': 128,
'layer-2-size': 128,
'initial-lr': 0.01,
'decay-steps': 2000,
'decay-rate': 0.9,
'optimizer': 'adamax'
}
mnist = tf.keras.datasets.mnist
num_class = 10
# split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# reshape and normalize the data
x_train = x_train.reshape(60000, 784).astype("float32")/255
x_test = x_test.reshape(10000, 784).astype("float32")/255
# convert class vectors to binary class matrices
y_train = to_categorical(y_train, num_class)
y_test = to_categorical(y_test, num_class)

4、创建训练和测试集

由于TensorFlow库也包括了MNIST数据集,因此您可以通过调用对象上的 datasets.mnist ,再调用load_data() 的方法,来分别获取训练(60,000个样本)和测试(10,000个样本)的数据集。

接着,您需要对训练和测试的图像进行整形和归一化。其中,归一化会将图像的像素强度限制在0和1之间。

最后,我们使用之前已导入的to_categorical 方法,将训练和测试标签转换为已分类标签。这对于向TensorFlow框架传达输出的标签(即:0到9)为类(class),而不是数字类型,是非常重要的。

5、设计神经网络架构

下面,让我们来了解如何在细节上设计神经网络架构。

我们通过添加Flatten ,将2D图像矩阵转换为向量,以定义DNN(深度神经网络)的结构。输入的神经元在此处对应向量中的数字。

接着,我使用Dense() 方法,添加两个隐藏的密集层,并从之前已定义的“params”字典中提取各项超参数。我们可以将“relu”(Rectified Linear Unit)作为这些层的激活函数。它是神经网络隐藏层中最常用的激活函数之一。

然后,我们使用Dropout方法添加Dropout层。它将被用于在训练神经网络时,避免出现过拟合(overfitting)。毕竟,过度拟合模型倾向于准确地记住训练集,并且无法泛化那些不可见(unseen)的数据集。

输出层是我们网络中的最后一层,它是使用Dense() 方法来定义的。需要注意的是,输出层有10个神经元,这对应于类(数字)的数量。

# Model Definition
# Get parameters from logged hyperparameters
model = Sequential([
Flatten(input_shape=(784, )),
Dense(params('layer-1-size'), activatinotallow='relu'),
Dense(params('layer-2-size'), activatinotallow='relu'),
Dropout(params('dropout')),
Dense(10)
])
lr_schedule =
tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=experiment.get_parameter('initial-lr'),
decay_steps=experiment.get_parameter('decay-steps'),
decay_rate=experiment.get_parameter('decay-rate')
)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adamax',
loss=loss_fn,
metrics=['accuracy'])
model.fit(x_train, y_train,
batch_size=experiment.get_parameter('batch-size'),
epochs=experiment.get_parameter('epochs'),
validation_data=(x_test, y_test),)
score = model.evaluate(x_test, y_test)
# Log Model
model.save('tf-mnist-comet.h5')

6、训练

至此,我们已经定义好了架构。下面让我们用给定的训练数据,来编译和训练神经网络。

首先,我们以初始学习率、衰减步骤和衰减率作为参数,使用ExponentialDecay(指数衰减学习率)来定义学习率计划。

其次,将损失函数定义为CategoricalCrossentropy(用于多类式分类)。

接着,通过将优化器 (即:adamax)、损失函数、以及各项指标(由于所有类都同等重要、且均匀分布,因此我选择了准确性)作为参数,来编译模型。

然后,我们通过使用x_train、y_train、batch_size、epochs和validation_data去调用一个拟合方法,并拟合出模型。

同时,我们调用模型对象的评估方法,以获得模型在不可见数据集上的表现分数。

最后,您可以使用在模型对象上调用的save方法,保存要在生产环境中部署的模型对象。

7、小结

综上所述,我们讨论了为图像分类任务,训练深度神经网络的一些入门级的知识。您可以将其作为熟悉使用神经网络,进行图像分类的一个起点。据此,您可了解到该如何选择正确的参数集、以及架构背后的思考逻辑。

原文链接:https://www.kdnuggets.com/2022/12/guide-train-image-classification-model-tensorflow.html

来源:51CTO技术栈内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯