一、MNIST数据集的特点
部分数据:
MNIST数据集是一个手写体数据集,数据集中每一个样本都是一个0-9的手写数字。该数据集由四部分组成,训练图片集,训练标签集,测试图片集和测试标签集。其中,训练集中有60000个样本,测试集中有10000个样本。每张照片均为28*28的二值图片,为方便存储,官方已对图片集进行处理,将每一张图片变成了维度为(1,784)的向量。
MNIST数据集特点有:
1、每张照片均为28*28的二值图片
2、数字的笔画无间断
3、数字在图像中倾斜的角度不超过45°
4、所有数字在图像中所占的比例相同
二、LeNet模型原理
(一)卷积神经网络
卷积神经网络是一种特殊的多层神经网络,像其它的神经网络一样,卷积神经网络也使用一种反向传播算法来进行训练,不同之处在于网络的结构。卷积神经网络的网络连接具有局部连接、参数共享的特点。局部连接是相对于普通神经网络的全连接而言的,是指这一层的某个节点只与上一层的部分节点相连。参数共享是指一层中多个节点的连接共享相同的一组参数。
对于一个卷积神经网络,假如该网络的第k层有n个节点,k+1层为卷积层且有m个节点,则k+1层的每个节点只与k层的部分节点相连,此处假设只与k层的i个节点相连(局部连接);另外k+1层的每个节点的连接共享相同的参数、相同的bias(参数共享)。这样该卷积神经网络的第k、k+1层间共有m*i个连接、i+1个参数。由于i小于n且为常数,所以卷积层的连接数、参数数量的数量级约为O(n),远小于全连接的O(n
2
)的数量级。卷积神经网络的部分连接的结构如下图:
(二)LeNet-5模型
LeNet-5是Yann LeCun在1998年设计的用于手写数字识别的卷积神经网络,是早期卷积神经网络中最有代表性的实验系统之一。
它的网路结构如下图:
图一 LeNet-5的网络结构
LeNet-5共有7层(不包含输入),每层都包含可训练参数。
首先,输入图像大小为32×32的灰度图,比MNIST数据集的图片要大一些,此在训练整个网络之前,需要对28×28的图像加上paddings(即周围填充0)。
在LeNet-5的网络结构中,C1层是一个卷积层。该层使用6个大小为5×5的卷积核对输入层进行卷积运算,产生6个大小为28×28的特征图。卷积层的作用是对输入图像提取6个特征。C1层的连接结构如图二。
图二 C1层的连接结构
S2层是一个池化层(也称为下采样层)。这里采用max_pool(最大池化),池化的size定为2×2。对C1层的每一个特征图的长宽尺寸降到原来的1/2,得到6个14×14的特征通道,且数量不变。池化层作用是降低网络训练参数及模型的过拟合程度。S2层的网络连接结构如图三:
图三 S2层的网络连接结构
C3层仍为一个卷积层,选用大小为5×5的16种不同的卷积核。C3当中的每个特征图,都是S2层中的所有或其中几个特征图进行加权组合得到的。输出为16个10措10的特征图。该层的作用是提取深层特征。
S4层仍为一个池化层,size为2×2,仍采用max_pool。最后输出16个5×5的特征图,神经元个数也减少至16×5×5=400个。
C5层继续用5*5的卷积核对S4层的输出进行卷积,卷积核数量增加至120。这样C5层的输出图片大小为5-5+1=1。最终输出120个1×1的特征图。这里实际上是与S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1×1。C5层的连接方式如图四。
图四 C5层的连接方式
F6层是全连接层,F有84个节点,该层输出84张特征图。84个神经元与C5中的120个神经元全连接,再加上4个偏置项。连接方式如图五。
图五 F6层连接方式
三、实验结果
(一)训练测试结果:
1、训练次数为20时:
(二)匹配结果:
-
开始页面展示
-
选择图片进行识别
-
当把图像进行旋转后
出现错误
旋转后的图片不是维度为(1,784)的向量,所以程序无法对其进行识别。
四、相关代码
1、窗口界面代码:
import tensorflow as tf
import numpy as np
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
from tkinter import filedialog
import time
def creat_windows():
win = tk.Tk() # 创建窗口
sw = win.winfo_screenwidth()
sh = win.winfo_screenheight()
ww, wh = 400, 450
x, y = (sw-ww)/2, (sh-wh)/2
win.geometry("%dx%d+%d+%d"%(ww, wh, x, y-40)) # 居中放置窗口
win.title('手写体识别') # 窗口命名
bg1_open = Image.open("timg.jpg").resize((300, 300))
bg1 = ImageTk.PhotoImage(bg1_open)
canvas = tk.Label(win, image=bg1)
canvas.pack()
var = tk.StringVar() # 创建变量文字
var.set('')
tk.Label(win, textvariable=var, bg='#C1FFC1', font=('宋体', 21), width=20, height=2).pack()
tk.Button(win, text='选择图片', width=20, height=2, bg='#FF8C00', command=lambda:main(var