文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

PythonOpencv使用ann神经网络识别手写数字功能

2024-04-02 19:55

关注

opencv中也提供了一种类似于Keras的神经网络,即为ann,这种神经网络的使用方法与Keras的很接近。
关于mnist数据的解析,读者可以自己从网上下载相应压缩文件,用python自己编写解析代码,由于这里主要研究knn算法,为了图简单,直接使用Keras的mnist手写数字解析模块。
本次代码运行环境为:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46

下面的代码为使用ann进行模型的训练:

from keras.datasets import mnist
from keras import utils
import cv2
import numpy as np
#opencv中ANN定义神经网络层
def create_ANN():
    ann=cv2.ml.ANN_MLP_create()
    #设置神经网络层的结构 输入层为784 隐藏层为80 输出层为10
    ann.setLayerSizes(np.array([784,64,10]))
    #设置网络参数为误差反向传播法
    ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)
    #设置激活函数为sigmoid
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    #设置训练迭代条件 
    #结束条件为训练30次或者误差小于0.00001
    ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001))
    return ann
#计算测试数据上的识别率
def evaluate_acc(ann,test_images,test_labels):
    #采用的sigmoid激活函数,需要对结果进行置信度处理 
    #对于大于0.99的可以确定为1 对于小于0.01的可以确信为0
    test_ret=ann.predict(test_images)
    #预测结果是一个元组
    test_pre=test_ret[1]
    #可以直接最大值的下标 (10000,)
    test_pre=test_pre.argmax(axis=1)
    true_sum=(test_pre==test_labels)
    return true_sum.mean()
if __name__=='__main__':
    #直接使用Keras载入的训练数据(60000, 28, 28) (60000,)
    (train_images,train_labels),(test_images,test_labels)=mnist.load_data()
    #变换数据的形状并归一化
    train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784)
    train_images=train_images.astype('float32')/255
    test_images=test_images.reshape(test_images.shape[0],-1)
    test_images=test_images.astype('float32')/255
    #将标签变为one-hot形状 (60000, 10) float32
    train_labels=utils.to_categorical(train_labels)
    #测试数据标签不用变为one-hot (10000,)
    test_labels=test_labels.astype(np.int)
    
    #定义神经网络模型结构
    ann=create_ANN()
    #开始训练    
    ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels)
    #在测试数据上测试准确率
    print(evaluate_acc(ann,test_images,test_labels))
    
    #保存模型
    ann.save('mnist_ann.xml')
    #加载模型
    myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')

训练100次得到的准确率为0.9376,可以接着增加训练次数或者提高神经网络的层次结构深度来提高准确率。
使用ann神经网络的模型结构非常小,因为只是保存了权重参数。

在这里插入图片描述

可以看到整个模型文件的大小才1M,而svm的大小为十多兆,knn的为几百兆,因此使用ann神经网络更加适合部署在客户端上。
接下来使用ann进行图片的测试识别:

import cv2
import numpy as np
if __name__=='__main__':
    #读取图片
    img=cv2.imread('shuzi.jpg',0)
    img_sw=img.copy()
    #将数据类型由uint8转为float32
    img=img.astype(np.float32)
    #图片形状由(28,28)转为(784,)
    img=img.reshape(-1,)
    #增加一个维度变为(1,784)
    img=img.reshape(1,-1)
    #图片数据归一化
    img=img/255
    #载入ann模型
    ann=cv2.ml.ANN_MLP_load('minist_ann.xml')
    #进行预测
    img_pre=ann.predict(img)
    #因为激活函数sigmoid,因此要进行置信度处理
    ret=img_pre[1]
    ret[ret>0.9]=1
    ret[ret<0.1]=0
    print(ret)
    cv2.imshow('test',img_sw)
    cv2.waitKey(0)

运行程序,结果如下,可见该模型正确识别了数字0.

在这里插入图片描述

到此这篇关于Python Opencv使用ann神经网络识别手写数字的文章就介绍到这了,更多相关python opencv识别手写数字内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯