Keras的应用模块Applications提供了带有预训练权重的Keras模型,这些模型可以用来进行预测、特征提取和微调(fine-tuning)。
今早跑了第一个官方示例程序:利用ResNet50网络对图片进行ImageNet分类。
测试图片【非洲象】:
ResNet50结构:
源码以及详细注释如下:
# -*- coding: UTF-8 -*-
# -------------------------------------------
# Task: classify an image into ImageNet classes with the ResNet50 network
# Data: a test picture downloaded from the web, 'elephant.jpg'
# -------------------------------------------
import os

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.applications.resnet50 import ResNet50
from keras.applications.resnet50 import preprocess_input, decode_predictions
from keras.preprocessing import image
from keras.utils import plot_model
from matplotlib import pyplot as plt

# Restrict this process to the GPU with index 6. The index is specific to the
# machine the script was written on; it has to be set before the first
# TensorFlow session is created for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

# allow_growth=True asks TensorFlow to grow its GPU memory allocation on
# demand instead of reserving (nearly) all of the card up front.
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

# Hand the configured session to Keras. Without this call Keras lazily creates
# its own default session and the one configured above is never used.
K.set_session(sess)
# [0] Build the ResNet50 model and load the ImageNet pre-trained weights.
model = ResNet50(weights='imagenet')
# summary() prints the layer table itself and returns None, so wrapping it in
# print() would emit a stray "None" line after the table.
model.summary()
# Draw the model architecture and save it as an image file.
plot_model(model, to_file='a simple convnet.png')

# [1] Test picture downloaded from the web, stored in the current directory.
img_path = './elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))

# [2] Display the picture.
plt.imshow(img)
plt.show()

# [3] Convert the picture into a 4-D tensor (a batch holding one image).
x = image.img_to_array(img)
print(x.shape)  # (224, 224, 3)
x = np.expand_dims(x, axis=0)
print(x.shape)  # (1, 224, 224, 3)

# [4] Data preprocessing.
# For reference, the docstring of preprocess_input as quoted by the author:
#
#   def preprocess_input(x, data_format=None, mode='caffe'):
#       Preprocesses a tensor or Numpy array encoding a batch of images.
#       # Arguments
#           x: Input Numpy or symbolic tensor, 3D or 4D.
#           data_format: Data format of the image tensor/array.
#           mode: One of "caffe", "tf".
#               - caffe: will convert the images from RGB to BGR,
#                 then will zero-center each color channel with
#                 respect to the ImageNet dataset,
#                 without scaling.
#               - tf: will scale pixels between -1 and 1,
#                 sample-wise.
#       # Returns
#           Preprocessed tensor or Numpy array.
#       # Raises
#           ValueError: In case of unknown `data_format` argument.
#
# Zero-centre the input (mean subtraction), as described in the quoted docstring.
x = preprocess_input(x)

# [5] Run the model on the test data.
preds = model.predict(x)
print(preds.shape)  # (1, 1000)

# [6] Decode the raw predictions into the form
#     [(class1, description1, prob1), (class2, description2, prob2), ...]
print('Predicted:', decode_predictions(preds, top=3)[0])
# Output recorded by the author:
# 'Predicted:', [(u'n02504458', u'African_elephant', 0.8791098), (u'n02504013', u'Indian_elephant', 0.066597864), (u'n01871265', u'tusker', 0.054188617)]