本文将一个完整的tf2.0框架下使用CNN模型解决图像分类问题 喜欢记得关注我 点收藏不迷路 辛苦整理免费分享的
import glob
import os
import cv2
import numpy as np
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import losses,layers,optimizers
from tensorflow.keras.callbacks import EarlyStopping
tf.random.set_seed(2222)
np.random.seed(2222)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
assert tf.__version__.startswith('2.')
def Data_Generation():
X_data=[];Y_data=[]
path_data=[];path_label=[]
#path_file=os.getcwd()
files=os.listdir('pokemon')
for file in files:
print(file)
for path in glob.glob('pokemon/'+file+'/*.*'):
if 'jpg' or 'png' or 'jpeg' in path:
path_data.append(path)
random.shuffle (path_data) #打乱数据
for paths in path_data:
if 'bulbasaur' in paths:
path_label.append(0)
elif 'charmander' in paths:
path_label.append(1)
elif 'mewtwo' in paths:
path_label.append(2)
elif 'pikachu' in paths:
path_label.append(3)
elif 'squirtle' in paths:
path_label.append(4)
img=cv2.imread(paths)
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img=cv2.resize(img,(224,224))
X_data.append(img)
L=len(path_data)
Y_data=path_label
X_data=np.array(X_data,dtype=float)
Y_data=np.array(Y_data,dtype='uint8')
X_train=X_data[0:int(L*0.6)]
Y_train=Y_data[0:int(L*0.6)]
X_valid=X_data[int(L*0.6):int(L*0.8)]
Y_valid=Y_data[int(L*0.6):int(L*0.8)]
X_test=X_data[int(L*0.8):]
Y_test=Y_data[int(L*0.8):]
return X_train,Y_train,X_valid,Y_valid,X_test,Y_test,L
def normalize(x):
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
x = (x - img_mean)/img_std
return x
def preprocess(x,y):
x=tf.image.resize(x,[244,244])
x=tf.image.random_flip_left_right(x)
x=tf.image.random_crop(x,[224,224,3])
x = tf.cast(x, dtype=tf.float32) / 255.
x = normalize(x)
y = tf.convert_to_tensor(y)
y = tf.one_hot(y, depth=5)
return x,y
X_train,Y_train,X_valid,Y_valid,X_test,Y_test,L=Data_Generation()
batchsz=32
#print(shape(X_data), shape(Y_data))
train_db = tf.data.Dataset.from_tensor_slices((X_train,Y_train))
train_db = train_db.shuffle(10000).map(preprocess).batch(batchsz)
valid_db = tf.data.Dataset.from_tensor_slices((X_valid,Y_valid))
valid_db = valid_db.map(preprocess).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices((X_test,Y_test))
test_db = test_db.map(preprocess).batch(batchsz)
net=keras.applications.DenseNet121(weights='imagenet',include_top=False,pooling='max')#这里使用了自带的DenseNet121网络 你也可以用keras.Sequential DIY模型
net.trainable=False
mynet=keras.Sequential([
net,
layers.Dense(1024,activation='relu'),
layers.BatchNormalization(), #BN层 标准化数据
layers.Dropout(rate=0.2),
layers.Dense(5)])
mynet.build(input_shape=(4,224,224,3))
mynet.summary()
early_stopping=EarlyStopping( #防止过拟合
monitor='val_accuracy',
min_delta=0.01,
patience=3)
mynet.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = mynet.fit(train_db, validation_data=valid_db, validation_freq=1, epochs=50,
callbacks=[early_stopping])
history = history.history
mynet.evaluate(test_db)
#训练结束以后保存mmodel文件到本地方便做图片分类的时候直接调用
#way1 保存model成 .pb 格式 方便各个平台(移动端等)的调用
tf.saved_model.save(mynet,'densenet')
#way2 保存model成 .h5格式 里面包含了模型结构和训练好的模型参数
mynet.save('densenet.h5')
因为设置了随机种子seed 所以每次的train validation test 集都一样,如果结果不满意的话 代码里有很多超参数可以调
程序运行完成后就可以在代码所在的文件夹里发现多了一个densenet文件夹 打开之后如下图,pb后缀的文件就是我们保存的模型文件
也可以生成 .h5 文件
生成模型后就可以用在网上下载图片来测试了,这里我用了电影大侦探皮卡丘的一张图做测试 pika pika~
import tensorflow as tf
from tensorflow import keras
import cv2
label=['bulbasaur','charmander','mewtwo','pikachu','squirtle' ]
network = keras.models.load_model('densenet.h5')
network.summary()
image=cv2.imread('test.jpeg')
img=image.copy()
img=cv2.resize(img,(224,224))
img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
def normalize(x):
img_mean = tf.constant([0.485, 0.456, 0.406])
img_std = tf.constant([0.229, 0.224, 0.225])
x = (x - img_mean)/img_std
return x
def preprocess(x):
x = tf.expand_dims(x,axis=0)
x = tf.cast(x, dtype=tf.float32) / 255.
x = normalize(x)
return x
img=preprocess(img)
#img= tf.cast(img, dtype=tf.uint8)
result=network(img)
result=tf.nn.softmax(result)
index=tf.argmax(result,axis=-1)
print(label[int(index)])
cv2.putText(image,label[int(index)],(166,54),cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 1.2, (255,0,0),2)
cv2.imshow('img',image)
cv2.waitKey()
cv2.destroyAllWindows()
运行代码后,就可以看到如下的效果了,图片的上方多了所属类的标签
版权声明:本文为Oscarouyangyafei原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。