文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

MNIST机器学习入门

2023-01-31 07:58

关注

当我们开始学习编程的时候,第一件事往往是学习打印"Hello World"。就好比编程入门有Hello World,机器学习入门有MNIST。

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片。它也包含每一张图片对应的标签,告诉我们这个是数字几。

详细内容请参考:http://wiki.jikexueyuan.com/p...

文章末尾会给出相关python代码,运行环境是python3.6+anaconda+tensorflow,具体环境搭建本文不做阐述。

一、MNIST简介

官网链接:http://yann.lecun.com/exdb/mn...

这个MNIST数据库是一个手写数字的数据库,它提供了六万的训练集和一万的测试集。

它的图片是被规范处理过的,是一张被放在中间部位的28px*28px的灰度图。

总共4个文件:

train-images-idx3-ubyte: training set images
train-labels-idx1-ubyte: training set labels
t10k-images-idx3-ubyte: test set images
t10k-labels-idx1-ubyte: test set labels

图片都被转成二进制放到了文件里面,

所以,每一个文件头部几个字节都记录着这些图片的信息,然后才是储存的图片信息。

二、tensorflow手写数字识别步骤

1、 将要识别的图片转为灰度图,并且转化为28*28矩阵

2、 将28*28的矩阵转换成1维矩阵

3、 用一个1*10的向量代表标签,因为数字是0~9,如数字1对应的矩阵就是:[0,1,0,0,0,0,0,0,0,0]

4、 softmax回归预测图片是哪个数字的概率。

这里顺带说一下还有一个回归:logistic,因为这里我们表示的状态不只两种,因此需要使用softmax

三、标签介绍(有监督学习/无监督学习)

监督学习:利用一组已知类别的样本调整分类器的参数,使其达到所要求性能的过程,也称为监督训练或有教师学习举个例子,MNIST自带了训练图片和训练标签,每张图片都有一个对应的标签,比如这张图片是1,标签也就是1,用他们训练程序,之后程序也就能识别测试集中的图片了,比如给定一张2的图片,它能预测出他是2

无监督学习:其中很重要的一类叫聚类举个例子,如果MNIST中只有训练图片,没有标签,我们的程序能够根据图片的不同特征,将他们分类,但是并不知道他们具体是几,这个其实就是“聚类”

标签的表示

在这里标签的表示方式有些特殊,它也是使用了一个一维数组,而不是单纯的数字,上面也说了,他是一个一位数组,0表示方法[1,0,0,0,0,0,0,0,0,0],1表示[0,1,0,0,0,0,0,0,0,0],………,

主要原因其实是这样的,因为softmax回归处理后会生成一个1*10的数组,数组[0,0]的数字表示预测的这张图片是0的概率,[0,1]则表示这张图片表示是1的概率……以此类推,这个数组表示的就是这张图片是哪个数字的概率(已经归一化),

因此,实际上,概率最大的那个数字就是我们所预测的值。两者对应来看,标准的标签就是表示图片对应数字的概率为100%,而表示其它数字的概率为0,举个例子,0表示[1,0,0,0,0,0,0,0,0,0],可以理解为它表示0的概率为100%,而表示别的数字的概率为0.

softmax回归

这是一个分类器,可以认为是Logistic回归的扩展,Logistic大家应该都听说过,就是生物学上的S型曲线,它只能分两类,用0和1表示,这个用来表示答题对错之类只有两种状态的问题时足够了,但是像这里的MNIST要把它分成10类,就必须用softmax来进行分类了。

P(y=0)=p0,P(y=1)=p1,p(y=2)=p2……P(y=9)=p9.这些表示预测为数字i的概率,(跟上面标签的格式正好对应起来了),它们的和为1,即 ∑(pi)=1。

tensorflow实现了这个函数,我们直接调用这个softmax函数即可,对于原理,可以参考下面的引文,这里只说一下我们这个MNIST demo要用softmax做什么。

(注:每一个神经元都可以接收来自网络中其他神经元的一个或多个输入信号,神经元与神经元之间都对应着连接权值,所有的输入加权和决定该神经元是处于激活还是抑制状态。感知器网络的输出只能取值0或1,不具备可导性。而基于敏感度的训练算法要求其输出函数必须处处可导,于是引入了常见的S型可导函数,即在每个神经元的输出之前先经过S型激活函数的处理。)

交叉熵

通俗一点就是,方差大家都知道吧,用它可以衡量预测值和实际值的相差程度,交叉熵其实也是一样的作用,那为什么不用方差呢,因为看sigmoid函数的图像就会发现,它的两侧几乎就是平的,导致它的方差在大部分情况下很小,这样在训练参数的时候收敛地就会很慢,交叉熵就是用来解决这个问题的,它的公式是 −∑y′log(y) ,其中,y是我们预测的概率分布,y’是实际的分布。

梯度下降

上面那步也说了,有个交叉熵,根据大伙对方差的理解,值越小,自然就越好,因此我们也要训练使得交叉熵最小的参数,这里梯度下降法就派上用场了,这个解释见上一篇系列文章吧,什么叫训练参数呢,可以想象一下,我们先用实际的值在二位坐标上画一条线,然后我们希望我们预测出来的那些值要尽可能地贴近这条线,我们假设生成我们这条线的公式ax+ax^2+bx^3+…..,我们需要生成这些系数,要求得这些系数,我们就需要各种点代入,然后才能求出,所以其实训练参数跟求参数是个类似的过程。

预测

训练结束以后我们就可以用这个模型去预测新的图片了,大概意思就是输入对应的值就能获取相应的结果。

代码如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import tensorflow.examples.tutorials.mnist.input_data
"""MNIST机器学习"""
"""下载并读取数据源"""
mnist = read_data_sets("MNIST_data/", one_hot=True)
"""x是784维占位符,基础图像28x28=784"""
x = tf.placeholder(tf.float32,[None,784])
"""W表示证据值向量,因为总数据量为0~9,每一维对应不同的数字,因此为10"""
W = tf.Variable(tf.zeros([784,10]))
"""b同理,为10"""
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W)+b)
"""初始化交叉熵的值"""
y_ = tf.placeholder("float",[None,10])
"""计算交叉熵"""
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
"""梯度下降算法"""
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
"""开始训练模型"""
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})
"""评估模型"""
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
"""正确率"""
k = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
print(k)
阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯