# ===== preprocess.py =====
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
# Root folder scanned by create_image_list: each sub-directory holding JPEG
# files is treated as one class.
input_data = './flower_photos'
# Destination of the processed train/validation/test splits (written by np.save).
output_file = './flower_processed_data.npy'
# Percent of images routed to the validation and test splits; the remainder
# becomes training data.
validation_percentage, test_percentage = 10, 10
def _list_image_files(class_dir_name):
    """Return the sorted JPEG paths directly inside ``input_data/class_dir_name``.

    The extension patterns overlap on case-insensitive file systems (e.g.
    ``*.jpg`` and ``*.JPG`` match the same file on Windows), so the matches
    are collected in a set to keep every image exactly once.
    """
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    found = set()
    for extension in extensions:
        file_glob = os.path.join(input_data, class_dir_name, '*.' + extension)
        found.update(glob.glob(file_glob))
    return sorted(found)


def _build_image_decoder():
    """Add the JPEG-decoding ops to the default graph once and return them.

    Returns:
        (raw_jpeg, image): a string placeholder to feed raw file bytes into,
        and the resulting 299x299x3 float32 image tensor.
    """
    raw_jpeg = tf.placeholder(tf.string, name='raw-jpeg')
    # channels=3 forces RGB output, so grayscale JPEGs still yield 3 channels.
    image = tf.image.decode_jpeg(raw_jpeg, channels=3)
    # decode_jpeg produces uint8; convert_image_dtype rescales it to float32.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_images(image, [299, 299])
    return raw_jpeg, image


def create_image_list(sess, test_percentage, validation_percentage):
    """Load every class of JPEGs under ``input_data`` and split it randomly.

    Each sub-directory of ``input_data`` that contains JPEG files is one
    class. Classes are numbered 0, 1, ... in sorted directory-name order, so
    label ids do not depend on file-system listing order. Every image is
    decoded, converted to float32 and resized to 299x299x3. It then goes to
    the validation split with probability ``validation_percentage``%, to the
    test split with probability ``test_percentage``%, and otherwise to the
    training split.

    Args:
        sess: tf.Session used to run the decoding graph.
        test_percentage: percentage (0-100) of images put in the test split.
        validation_percentage: percentage (0-100) of images put in the
            validation split.

    Returns:
        A 1-D object ndarray of six entries: [train_images, train_labels,
        validation_images, validation_labels, test_images, test_labels].
        Image entries are lists of image arrays; label entries are lists of
        int class ids. The training pair is shuffled in unison.

    Raises:
        FileNotFoundError: if ``input_data`` does not exist.
    """
    # Build the decoding graph once; creating new ops for every image would
    # make the graph (and memory use) grow with each file processed.
    raw_jpeg, decoded_image = _build_image_decoder()

    # Only the immediate sub-directories are classes (their files are globbed
    # as input_data/<name>/*.ext); sorting makes the label ids reproducible.
    class_dirs = sorted(
        entry for entry in os.listdir(input_data)
        if os.path.isdir(os.path.join(input_data, entry)))

    train_images, train_labels = [], []
    validation_images, validation_labels = [], []
    test_images, test_labels = [], []
    current_label = 0
    for dir_name in class_dirs:
        file_list = _list_image_files(dir_name)
        # Directories without JPEGs do not consume a label id.
        if not file_list:
            continue
        for file_name in file_list:
            print(file_name)
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image_value = sess.run(decoded_image,
                                   feed_dict={raw_jpeg: image_raw_data})
            # randint(100) is uniform on 0..99, so each comparison selects
            # the given percentage of images.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < validation_percentage + test_percentage:
                test_images.append(image_value)
                test_labels.append(current_label)
            else:
                train_images.append(image_value)
                train_labels.append(current_label)
        current_label += 1

    # Shuffle images and labels with the same RNG state so pairs stay aligned.
    state = np.random.get_state()
    np.random.shuffle(train_images)
    np.random.set_state(state)
    np.random.shuffle(train_labels)

    # Fill an object array element by element; np.asarray on these ragged
    # lists raises ValueError on NumPy >= 1.24.
    splits = [train_images, train_labels,
              validation_images, validation_labels,
              test_images, test_labels]
    processed_data = np.empty(len(splits), dtype=object)
    for index, split in enumerate(splits):
        processed_data[index] = split
    return processed_data
def main():
    """Build the train/validation/test splits and save them to ``output_file``."""
    session = tf.Session()
    with session:
        data = create_image_list(session, test_percentage, validation_percentage)
    # The splits are plain NumPy data, so they can be written after the
    # session has been closed.
    np.save(output_file, data)


if __name__ == '__main__':
    main()
# ===== finetune.py =====
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# Dataset file produced by the preprocessing step (indexed 0..5 by main()).
input_data = 'flower_processed_data.npy'
# Checkpoint path prefix handed to tf.train.Saver.save.
train_file = 'model'
# Pretrained Inception-v3 checkpoint to start from.
ckpt_file = './checkpoints/inception_v3.ckpt'

# Optimisation hyper-parameters: RMSProp step size, number of training
# steps, and mini-batch size.
learning_rate, steps, batch = 1e-4, 300, 32
# Number of classes predicted by the new classification head.
n_classes = 5

# Comma-separated variable-scope prefixes that are NOT restored from the
# checkpoint (the new, differently-sized output layers).
checkpoint_exclude_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# Comma-separated variable-scope prefixes whose variables
# get_trainable_variables() collects (intended to be the fine-tuned part).
trainable_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'
def get_tuned_variables():
    """Return the slim model variables that should be restored from the checkpoint.

    A variable is skipped when its op name starts with any of the
    comma-separated prefixes in ``checkpoint_exclude_scopes``; all other
    variables from ``slim.get_model_variables()`` are kept, in their
    original order.
    """
    excluded_prefixes = tuple(
        prefix.strip() for prefix in checkpoint_exclude_scopes.split(','))
    # str.startswith accepts a tuple and is True if any prefix matches.
    return [
        model_var
        for model_var in slim.get_model_variables()
        if not model_var.op.name.startswith(excluded_prefixes)
    ]
def get_trainable_variables():
    """Collect the trainable variables that live under ``trainable_scopes``.

    ``trainable_scopes`` is a comma-separated list of scope prefixes. For each
    prefix, in order, the matching entries of the TRAINABLE_VARIABLES
    collection are appended to the returned list.
    """
    scope_prefixes = [prefix.strip() for prefix in trainable_scopes.split(',')]
    return [
        variable
        for prefix in scope_prefixes
        for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          prefix)
    ]
def _evaluate_accuracy(sess, accuracy, images, labels, is_training,
                       data_images, data_labels, batch_size):
    """Return the accuracy over a whole split, computed in mini-batches.

    Feeding an entire split through Inception-v3 in one ``sess.run`` can
    exhaust device memory, so the split is evaluated ``batch_size`` examples
    at a time and the per-batch accuracies are combined weighted by batch size.

    Args:
        sess: session holding the trained graph.
        accuracy: scalar tensor with the mean accuracy of one fed batch.
        images, labels: placeholders the batch is fed into.
        is_training: boolean placeholder; fed False so batch norm and dropout
            run in inference mode.
        data_images, data_labels: the split to evaluate (sliceable sequences).
        batch_size: number of examples per ``sess.run`` call.

    Returns:
        The example-weighted mean accuracy as a float, or NaN for an empty split.
    """
    n_examples = len(data_images)
    if n_examples == 0:
        return float('nan')
    n_correct = 0.0
    for start in range(0, n_examples, batch_size):
        end = min(start + batch_size, n_examples)
        batch_accuracy = sess.run(accuracy, feed_dict={
            images: data_images[start:end],
            labels: data_labels[start:end],
            is_training: False,
        })
        n_correct += batch_accuracy * (end - start)
    return n_correct / n_examples


def main(argv=None):
    """Fine-tune Inception-v3 on the preprocessed flower data.

    Restores every pretrained model variable except the scopes in
    ``checkpoint_exclude_scopes``. Only the variables under
    ``trainable_scopes`` are trained, with RMSProp, for ``steps``
    mini-batches. A checkpoint is saved and validation accuracy printed
    every 30 steps (and on the last step). Finally the test accuracy is
    printed.

    Args:
        argv: unused; present because ``tf.app.run`` calls ``main(argv)``.

    Raises:
        ValueError: if the dataset contains no training examples.
    """
    # Load data. The file holds an object array of six Python lists, which
    # NumPy only loads with allow_pickle=True (default False since 1.16.3).
    # Unpickling can execute code: only load files you produced yourself.
    processed_data = np.load(input_data, allow_pickle=True)
    train_images = processed_data[0]
    train_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    test_images = processed_data[4]
    test_labels = processed_data[5]
    n_train_examples = len(train_images)
    print('train: {},test: {},validation: {}'.format(
        n_train_examples, len(test_images), len(validation_images)))
    if n_train_examples == 0:
        raise ValueError('no training examples found in ' + input_data)

    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='input-images')
    labels = tf.placeholder(tf.int64, [None], name='labels')
    # Training mode by default; evaluation feeds False so batch norm uses its
    # moving statistics and dropout is disabled.
    is_training = tf.placeholder_with_default(True, shape=[], name='is-training')

    # Inception-v3 model with a new n_classes-way output head.
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(images, n_classes,
                                              is_training=is_training)
    trainable_variables = get_trainable_variables()

    # tf.losses adds the loss to its collection; get_total_loss() sums it.
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, n_classes), logits,
                                    weights=1.0)
    total_loss = tf.losses.get_total_loss()
    # Batch-norm moving averages are updated by the ops in UPDATE_OPS, which
    # only run if the train step depends on them.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        # Restrict training to the configured scopes; the rest of the network
        # keeps its pretrained weights.
        train_step = tf.train.RMSPropOptimizer(learning_rate).minimize(
            total_loss, var_list=trainable_variables)

    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Loader for every pretrained model variable except the excluded scopes.
    load_fn = slim.assign_from_checkpoint_fn(
        ckpt_file, get_tuned_variables(), ignore_missing_vars=True)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Initialise everything, then overwrite the pretrained part with the
        # checkpoint values.
        sess.run(tf.global_variables_initializer())
        print('loading tuned variables from : ' + ckpt_file)
        load_fn(sess)

        start = 0
        for i in range(steps):
            # The last batch of a pass may be shorter than `batch`.
            end = min(start + batch, n_train_examples)
            sess.run(train_step, feed_dict={
                images: train_images[start:end],
                labels: train_labels[start:end],
            })
            if i % 30 == 0 or i + 1 == steps:
                saver.save(sess, train_file, global_step=i)
                acc = _evaluate_accuracy(sess, accuracy, images, labels,
                                         is_training, validation_images,
                                         validation_labels, batch)
                print('{} steps, accuracy: {:.4f}'.format(i, acc))
            # Continue after this batch, wrapping to the start after the last one.
            start = 0 if end == n_train_examples else end

        test_acc = _evaluate_accuracy(sess, accuracy, images, labels,
                                      is_training, test_images, test_labels,
                                      batch)
        print('test accuracy: {}'.format(test_acc))


if __name__ == '__main__':
    tf.app.run()
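# 版权声明:本文为a13602955218原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
# (Copyright notice: original article by a13602955218, licensed under CC 4.0 BY-SA; when reposting, include a link to the original source and this notice.)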