文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

快速学习一个算法,卷积神经网络

2024-11-29 18:34

关注

卷积神经网络(CNN)是一种专门用于处理具有网格结构数据的神经网络架构,最常见的应用领域是图像处理。

与传统的全连接神经网络不同,CNN 通过局部感知和参数共享来有效地处理高维数据,使其在图像分类、目标检测、语义分割等任务中表现出色。

图片

卷积神经网络的基本结构

卷积神经网络的核心思想是通过局部感知区域和权重共享来有效减少参数数量,同时保留空间信息。它通常由卷积层池化层以及全连接层组成。

卷积层

卷积层是 CNN 的核心部分,用来提取输入数据的局部特征。

卷积层通过多个卷积核对输入进行卷积操作,生成特征图(feature maps)。

如下图所示,对于大小为 7x7x3 的输入,应用两个卷积核,每个卷积核通过对三个输入通道进行卷积来提取不同的特征图。

图片

卷积核是一组权重,它们通过滑动窗口的方式在输入上进行卷积运算。

每个卷积核会与输入的局部区域进行点积,生成一个值,这些值组成输出特征图。

图片


卷积层通常有三个重要参数

图片

池化层

池化层用于对卷积层输出的特征图进行下采样,减少特征图的尺寸,从而减少计算量和内存需求,同时提高模型的鲁棒性。

池化层通常有最大池化(Max Pooling)和平均池化(Average Pooling)两种。

图片


全连接层

全连接层与普通的前馈神经网络类似,是 CNN 的后几层,它通常用在卷积层和池化层提取到的特征图之后,用来进行分类或回归任务。

全连接层的主要作用是对卷积层提取的特征进行进一步的组合和处理,从而输出模型的最终预测结果。

图片


案例分享

以下是一个使用卷积神经网络(CNN)进行手写数字识别的案例代码,基于经典的 MNIST 数据集。

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

train_images = train_images.reshape((train_images.shape[0], 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1)).astype('float32') / 255

# 搭建 CNN 模型
model = models.Sequential()
# 第一层卷积层:卷积核大小为 3x3,输出 32 个特征图
model.add(layers.Conv2D(32, (3, 3), activatinotallow='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))  # 2x2 最大池化层
# 第二层卷积层
model.add(layers.Conv2D(64, (3, 3), activatinotallow='relu'))
model.add(layers.MaxPooling2D((2, 2)))
# 第三层卷积层
model.add(layers.Conv2D(64, (3, 3), activatinotallow='relu'))
# 将特征图展平
model.add(layers.Flatten())
# 全连接层
model.add(layers.Dense(64, activatinotallow='relu'))
model.add(layers.Dense(10, activatinotallow='softmax'))



model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.summary()

history = model.fit(train_images, train_labels, epochs=5, 
                    validation_data=(test_images, test_labels))

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print(f'\nTest accuracy: {test_acc}')

# 可视化训练过程
plt.plot(history.history['accuracy'], label='Accuracy rate')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

图片

来源:程序员学长内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯