mnist_cnn训练保存模型然后去识别手写数字

  • Post author:
  • Post category:其他


mnist是很多人入门机器/深度学习的入门数据集,但是只是用来测试模型和入门学习,而忽略了mnist是一个非常好的数字识别的库。

那么我使用一个非常简单,大概5-6层卷积+池化再加几层全连接的结构来训练一下mnist,然后保存下模型,当我想识别一个字符的时候就可以直接读取这个模型,然后识别这个字符了。

首先是网络模型

net =slim.repeat(net,1,slim.conv2d, 32, [3, 3], scope = 'conv1')
net = slim.max_pool2d(net,[3,3],scope ='pool1',stride = 2)
'''
14*14*32
'''
net = slim.repeat(net, 1, slim.conv2d, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [3, 3], scope='pool2',stride = 2)
'''
7*7*64
'''
net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [3, 3], scope='pool3',stride = 2,padding="VALID")
'''
4*4*128
'''
net = slim.repeat(net, 1, slim.conv2d, 256, [3, 3], scope='conv4')
'''
4*4*256
'''
net = slim.flatten(net, scope='flatten')
net = slim.dropout(net, keep_prob=0.8,
                   is_training=self._is_training)
net = slim.fully_connected(net, 1024, scope='fc1')
net = slim.fully_connected(net, 64, scope='fc2')
net = slim.fully_connected(net, self.num_classes,
                           activation_fn=None, scope='fc3')

然后定义输入的张量的shape是[None,784],标签是[None],然后将这个输入的tensor转化一下shape,转化成可以进行卷积操作的shape

inputs = tf.placeholder(tf.float32, shape=[None, 784], name='inputs')
labels = tf.placeholder(tf.int32, shape=[None], name='labels')
cls_model = model_mnist.Model(is_training=True, num_classes=10)
image = tf.reshape(inputs,[-1,28,28,1])

然后识别的时候将图片转化为[1,784]的格式,一次识别一张的话。

import numpy as np
import tensorflow as tf
import cv2
import os
import time

model_ckpt_path = "D:/all_model/mnist_model/model.ckpt"

def main(_):
    with tf.Session() as sess:
        ckpt_path = model_ckpt_path
        saver = tf.train.import_meta_graph(ckpt_path + '.meta')
        saver.restore(sess, ckpt_path)
        inputs = tf.get_default_graph().get_tensor_by_name('inputs:0')
        classes = tf.get_default_graph().get_tensor_by_name('classes:0')
        image = cv2.imread("D:/5.jpg", cv2.IMREAD_GRAYSCALE)
        image = cv2.resize(image, (28, 28))
        image_np = np.resize(image,[1,784])
        predicted_label = sess.run(classes, feed_dict={inputs: image_np})
        print(predicted_label)
if __name__ == '__main__':
    tf.app.run()



版权声明:本文为bomingzi原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。