手写体数字识别(Python+TensorFlow)
注意:该作者博客已迁移至https://buxianshan.xyz
先看结果
在MNIST数据集10000张测试图片上的正确率
测试手写数字图片(20张)
原图
测试结果
源文件下载:
- CSDN下载 https://download.csdn.net/download/qq_43479622/11227413(需要5C币)
- 没有C币的也可以到GitHub下载 https://github.com/BuXianShan/Handwritten-Numeral-Recognition
声明:
本文大部分程序参考《TensorFlow实战Google深度学习框架》,很适合深度学习入门的书籍。
解压文件打开后如图
__pycache__文件夹是Python自动生成的,不用管它,想了解它可以参考 https://blog.csdn.net/yitiaodashu/article/details/79023987
MNIST_model文件夹保存了已经训练30000次的模型
picture文件夹存放的是自己手写数字的图片
需要的安装的Python模块
- tensorflow
- opencv-python
- pillow
(这里我建议全部使用pip安装,用国内镜像下载特别快,如果以后计算量大需要使用gpu版本的tensorflow再重新安装gpu版本的。opencv-python和pillow都是关于图像处理的,只有app.py文件使用到了。)
关于MNIST数据集
MNIST是深度学习的经典入门demo,他是由6万张训练图片和1万张测试图片构成的,每张图片都是28*28的灰度图,像素取值为0~1。这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,每一张图片是一个长度为784的一维数组。
from tensorflow.examples.tutorials.mnist import input_data
便会自动下载封装好的数据集。
mnist_inference.py文件定义了前向传播过程以及神经网络的参数
三层全连接网络结构,通过加入隐藏层实现了多层网络结构。
mnist_train.py定义了神经网络的训练过程
运行mnist_train.py文件便会开始训练模型,MNIST_model文件已经有训练好的模型,你也可以删掉或修改然后重新训练。
mnist_eval.py文件定义了测试过程
运行mnist_eval.py文件就是计算在mnist数据集上测试1万张图片的正确率。
app.py文件实现了测试自己手写数字的图片
在picture文件夹保存要测试的图片,运行app.py文件即可输出测试结果
以上都是我测试过的文件,可以使用,下面记录了我遇到的困难
我遇到的困难
困难1:不知道模型训练好了怎么测试自己手写的图片
跟着书上开始做的时候,mnist_inference.py、mnist_train.py 和 mnist_eval.py这三个文件已经可以实现训练模型和测试正确率。但是由于这三个文件使用的mnist数据集,我还不知道数据集到底长什么样子,也不知道模型训练好了怎么测试自己手写的图片。就是感觉好像做了一个很厉害的东西,但是不知道怎么用的感觉。
然后把模型搭建的过程重新梳理了一遍,才知道从何下手。输入节点是长度为784的数组,所以得把我的图片转化为长度是784的数组,才能输入到模型里,才能得到结果。
代码请看app.py里的image_prepare()函数,通过使用图像处理库PIL把图片转化为灰度图并且修改尺寸为28*28,然后转化为数组。
困难2:测试自己手写图片的正确率太低
在mnist测试数据集上的正确率有98.52%,而测试自己手写数字的正确率几乎为0,大部分数字都被识别成8。
通过不断测试我总结了以下几个原因:
- mnist数据集图片是黑底白字,而我们平时都是白底黑字,所以要对测试图片灰度反转。
修改过后测试如图
已经可以识别几个数字了,但还是很多被识别成了8。原因是自己拍的图片有很多噪点,直接输入给模型就因为噪点太多,被误认为是8。 - 二值化来降噪
原图
使用opencv二值化图像cv2.threshold(img,127,255,cv2.THRESH_BINARY)
虽然还有少量噪点,但已经有很好的识别效果了。(还可以再调整阈值)
测试结果
基本上识别了所有数字。
(测试自己写的数字时很可能因为输入图像没处理好导致识别率太低。)
我使用的版本
python3.7
tensorflow 1.13.1
运行时可能有很多warning,不影响运行结果
总结
本文只大致介绍了使用方法,神经网络的相关基础知识还要多多了解。这里使用的是基于全连接层网络结构的神经网络,对数字识别已经有了不错的效果,但使用卷积神经网络还可以提高正确率(大约99.2%),比如LeNet-5模型。
ZYB_BO: 注意空格
对方正在输入.....131: 老哥,tensorflow没有exzmple库了
Clearea: 我win10成功了。不过我之前配置过Jupyter的默认启动目录,就是那个默认打开的起始位置,不知和这个有没有关联。
Kiwi Star: 想问一下为什么输入的数字在输出时被当作一整个数字输出了啊
ironman5202: 回车过后需要继续输入数组呢