LSTM/GRU/RNN/CNN手写数字分类【简洁,多种模式对比】

  • Post author:
  • Post category:其他





前言

环境搭建好以后,将代码复制到 Jupyter Notebook 中即可运行。结合我的另一篇博文《自定义数据集(X_train,y_train),(X_test,y_test)》,即可实现任意自定义数据集的分类,可应用于各个领域。




一、导入相关库

# 两种方法
import keras
import numpy as np
import pandas
from keras.utils import np_utils
from keras.datasets import mnist
from matplotlib import pyplot
from keras.models import Sequential
from keras.layers import Dense,Activation,Dropout,Convolution2D,MaxPool2D,Flatten
from keras.layers import SimpleRNN,LSTM,GRU,Embedding
from keras.layers import Dense, LSTM, Lambda, TimeDistributed, Input, Masking, Bidirectional
from keras.optimizers import SGD,Adam
import matplotlib.pyplot as plt



二、导入手写数字数据集

NUM_CLASS = 10  # MNIST has ten digit classes (0-9)

# Load MNIST: (images, labels) for the train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values (0-255) into [0, 1]. These 3-D (N, 28, 28) arrays feed
# the recurrent models, which read each image row as one timestep.
x_train1 = x_train / 255.0
x_test1 = x_test / 255.0


print('归一化以后======================')
print('=' * 100)
print('x_train1', x_train1.shape)
print('x_test1', x_test1.shape)
print('y_train', y_train.shape)
print('y_test', y_test.shape)
print('=' * 100)

# Add a trailing channel axis for the CNN: (N, 28, 28) -> (N, 28, 28, 1).
# Reuses the already-normalized arrays instead of dividing by 255 again.
x_train4d = x_train1.reshape(-1, 28, 28, 1)
x_test4d = x_test1.reshape(-1, 28, 28, 1)


print('转换维度以后======================')
print('x_train4d', x_train4d.shape)
print('x_test4d', x_test4d.shape)

# Convert integer labels to one-hot vectors of length NUM_CLASS.
y_train2d = np_utils.to_categorical(y_train, num_classes=NUM_CLASS)
y_test2d = np_utils.to_categorical(y_test, num_classes=NUM_CLASS)


print('y_train2d', y_train2d.shape)
print('y_test2d', y_test2d.shape)

# 3-D label variant (N, NUM_CLASS, 1). The sample count is inferred (-1)
# rather than hard-coded, so this works for any dataset size.
# NOTE(review): appears unused by the models in this script; kept for compatibility.
y_train3d = y_train2d.reshape(-1, NUM_CLASS, 1)
y_test3d = y_test2d.reshape(-1, NUM_CLASS, 1)

print('y_train3d', y_train3d.shape)
print('y_test3d', y_test3d.shape)



三、定义绘制 accuracy / loss 曲线的函数

def DrawLine(mhistory):
    """Plot per-epoch accuracy and loss curves from a Keras training history.

    Shows two figures in order: accuracy, then loss. Each has a 'train'
    curve and a 'test' curve (the validation data passed to ``fit``).

    Args:
        mhistory: the object returned by ``model.fit``. Its ``history`` dict
            must contain the training and ``val_``-prefixed metrics, so
            ``fit`` must have been given validation data.
    """
    metrics = mhistory.history
    # Keras >= 2.3 records metrics=['accuracy'] under 'accuracy'; older
    # releases recorded it as 'acc' (and 'val_acc').
    acc_key = 'accuracy' if 'accuracy' in metrics else 'acc'

    for key, title, ylabel in (
        (acc_key, 'model accuracy', 'accuracy'),
        ('loss', 'model loss', 'loss'),
    ):
        plt.plot(metrics[key])
        plt.plot(metrics['val_' + key])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        plt.legend(['train', 'test'], loc='upper left')
        plt.show()



四、选择模型运行

# Architecture to train; one of "lstm", "Bilstm", "GRU", "RNN", "CNN" (case-sensitive).
pMode = "RNN"

if pMode == "lstm":
    print('运行模式是:', 'lstm')
    model = Sequential()
    model.add(LSTM(units=50, return_sequences=True, input_shape=(28, 28)))
    model.add(Dropout(0.5))
    model.add(LSTM(units=20, return_sequences=False, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASS, activation='softmax'))

elif pMode == "Bilstm":
    print('运行模式是:', 'Bilstm')
    model = Sequential()
    # input_shape belongs on the Bidirectional wrapper (the first layer of the
    # Sequential model), not on the LSTM it wraps.
    model.add(Bidirectional(LSTM(50, return_sequences=True), input_shape=(28, 28)))
    model.add(Dropout(0.5))
    model.add(Bidirectional(LSTM(25, return_sequences=False)))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASS, activation='softmax'))

elif pMode == "GRU":
    print('运行模式是:', 'GRU')
    model = Sequential()
    model.add(GRU(units=50, return_sequences=True, input_shape=(28, 28)))
    model.add(GRU(units=50, return_sequences=True))
    model.add(Dropout(0.5))
    model.add(GRU(units=50, return_sequences=False))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASS, activation='softmax'))

elif pMode == "RNN":
    print('运行模式是:', 'RNN')
    model = Sequential()
    model.add(SimpleRNN(units=100, return_sequences=True, input_shape=(28, 28)))
    model.add(Dropout(0.5))
    model.add(SimpleRNN(units=50, return_sequences=False))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASS, activation='softmax'))

elif pMode == "CNN":
    print('运行模式是:', 'CNN')
    model = Sequential()
    model.add(Convolution2D(input_shape=(28, 28, 1), filters=32, kernel_size=5, strides=1, padding='same', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=2, padding='same'))
    model.add(Convolution2D(64, 5, strides=1, padding='same', activation='relu'))
    model.add(MaxPool2D(pool_size=2, strides=2, padding='same'))
    model.add(Flatten())
    model.add(Dense(1024, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(NUM_CLASS, activation='softmax'))

else:
    # Fail loudly instead of hitting a NameError on `model` below.
    raise ValueError(
        'Unknown pMode {!r}; expected one of "lstm", "Bilstm", "GRU", "RNN", "CNN"'.format(pMode))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
if pMode == "CNN":
    # The CNN consumes 4-D image tensors (N, 28, 28, 1).
    history = model.fit(x_train4d, y_train2d, batch_size=64, epochs=3, validation_data=(x_test4d, y_test2d))
else:
    # Recurrent models take (N, 28, 28): each image row is one timestep.
    history = model.fit(x_train1, y_train2d, batch_size=64, epochs=2, validation_data=(x_test1, y_test2d))
model.summary()
DrawLine(history)



五、画出混淆矩阵

from sklearn.metrics import confusion_matrix

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import itertools


# Tick labels for the confusion matrix: MNIST classes are the digits 0-9,
# so label k must read "k" (the original "1".."10" was off by one).
labels = [str(digit) for digit in range(10)]

def plot_confusion_matrix(cm,target_names,title='Confusion matrix',cmap=plt.cm.Greens,normalize=True):
    """Draw a confusion matrix as an annotated heat map and show it.

    Args:
        cm: square confusion matrix; rows are true labels (y axis) and
            columns are predicted labels (x axis).
        target_names: tick labels for both axes, or None to keep numeric ticks.
        title: figure title.
        cmap: matplotlib colormap; None falls back to 'Blues'.
        normalize: if True, each row is divided by its total so the cells show
            per-true-class rates (4 decimals); otherwise raw counts are shown.

    The overall accuracy and misclassification rate (from the raw counts)
    are printed under the x axis.
    """
    cm = np.asarray(cm)
    # Computed from raw counts, before any normalization.
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(15, 12))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=45)
        plt.yticks(tick_marks, target_names)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class with no samples would divide by zero (NaN + warning);
        # leave such rows at 0 instead.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

    # Cells above this value get white text so they stay readable on dark colors.
    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = "{:0.4f}".format(cm[i, j]) if normalize else "{:,}".format(cm[i, j])
        plt.text(j, i, cell_text, horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass))
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()
    # To save the figure, uncomment the line below and set the path and dpi (resolution).
    # plt.savefig('/content/drive/My Drive/Colab Notebooks/confusionmatrix32.png',dpi=350)
    plt.show()
# Show the confusion matrix of a trained classifier.
def plot_confuse(model, x_val, y_val):
    """Predict on ``x_val`` and plot the confusion matrix against ``y_val``.

    Args:
        model: trained Keras model with one output score per class.
        x_val: inputs in the shape the model was trained on.
        y_val: one-hot labels, shape (N, num_classes).
    """
    # predict_classes() was removed from newer Keras/TensorFlow. For a
    # multi-class (softmax) output it was just argmax over the class scores.
    predictions = np.argmax(model.predict(x_val, batch_size=64), axis=-1)
    truelabel = y_val.argmax(axis=-1)  # one-hot -> integer label
    # Pin the label set so the matrix is always num_classes x num_classes,
    # matching the tick labels even if a class never appears.
    conf_mat = confusion_matrix(y_true=truelabel, y_pred=predictions,
                                labels=np.arange(y_val.shape[-1]))
    # plot_confusion_matrix opens its own figure, so no plt.figure() here
    # (that previously left an extra empty figure).
    plot_confusion_matrix(conf_mat, normalize=False, target_names=labels, title='Confusion Matrix')
# Evaluate on the test set using the input layout the chosen model was
# trained with: 4-D (N, 28, 28, 1) for the CNN, 3-D (N, 28, 28) otherwise.
x_eval = x_test4d if pMode == "CNN" else x_test1
plot_confuse(model, x_eval, y_test2d)



总结

目前在 keras==2.3.1、tensorflow==1.15 版本下可以正常运行。有一定借鉴意义,请点赞收藏支持。



版权声明:本文为linmuquan1989原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。