TensorFlow实现迁移学习

  • Post author:
  • Post category:其他


preprocess.py

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Root folder of the raw images; every sub-directory is treated as one class.
input_data = './flower_photos'
# Where the packed train/validation/test splits are written with np.save.
output_file = './flower_processed_data.npy'

# Percent of images routed to the validation and test splits; the rest train.
validation_percentage, test_percentage = 10, 10

def _list_jpeg_files(class_dir):
    """Return the sorted, de-duplicated JPEG file paths directly inside class_dir.

    Both lower- and upper-case extensions are globbed. Where glob matching is
    case-insensitive (e.g. on Windows), '*.jpg' and '*.JPG' match the same
    files, which would otherwise be processed twice; collecting the matches
    into a set removes those duplicates.
    """
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    found = set()
    for extension in extensions:
        found.update(glob.glob(os.path.join(class_dir, '*.' + extension)))
    return sorted(found)


def create_image_list(sess, test_percentage, validation_percentage):
    """Split the JPEG images under ``input_data`` into train/validation/test sets.

    Every sub-directory of ``input_data`` that contains JPEG files is one
    class. Classes receive integer labels 0, 1, ... in sorted directory order.
    Each image is decoded as RGB, converted to float32
    (tf.image.convert_image_dtype) and resized to 299x299. It then goes to
    validation with probability validation_percentage/100, to test with
    probability test_percentage/100, and to train otherwise.

    Args:
        sess: tf.Session used to run the decode/resize ops.
        test_percentage: percent (0-100) of images assigned to the test split.
        validation_percentage: percent (0-100) of images assigned to the
            validation split.

    Returns:
        A 1-D object ndarray of length 6:
        [train_images, train_labels, validation_images, validation_labels,
        test_images, test_labels], each element a Python list. Training
        images and labels are shuffled with the same permutation.
    """
    # Build the decode -> float32 -> resize pipeline once and feed every file
    # through a placeholder. Creating new ops per image would keep adding
    # nodes to the default graph for every file processed.
    jpeg_bytes = tf.placeholder(tf.string, name='jpeg_bytes')
    # channels=3 forces RGB output, so grayscale JPEGs still yield 299x299x3.
    decoded = tf.image.decode_jpeg(jpeg_bytes, channels=3)
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    resized = tf.image.resize_images(as_float, [299, 299])

    # os.walk yields (dirpath, dirnames, filenames); its first entry is the
    # root directory itself, which is skipped. Sorting the remaining entries
    # makes the label assignment independent of filesystem listing order.
    walker = os.walk(input_data)
    next(walker, None)
    sub_dirs = sorted(entry[0] for entry in walker)

    train_images = []
    train_labels = []
    test_images = []
    test_labels = []
    validation_images = []
    validation_labels = []

    current_label = 0
    for sub_dir in sub_dirs:
        # For a path a/b/c, basename is c.
        dir_name = os.path.basename(sub_dir)
        file_list = _list_jpeg_files(os.path.join(input_data, dir_name))
        # A directory without JPEGs does not consume a label.
        if not file_list:
            continue

        for file_name in file_list:
            print(file_name)
            # Close the file handle explicitly instead of leaking it.
            with gfile.FastGFile(file_name, 'rb') as f:
                image_raw_data = f.read()
            image_value = sess.run(resized, feed_dict={jpeg_bytes: image_raw_data})

            # Draw 0..99 to route this image into exactly one split.
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < validation_percentage + test_percentage:
                test_images.append(image_value)
                test_labels.append(current_label)
            else:
                train_images.append(image_value)
                train_labels.append(current_label)

        # All images of this directory share one label; advance it only after
        # the whole directory has been processed.
        current_label += 1

    # Shuffle images and labels with the same permutation by replaying the
    # RNG state before the second shuffle.
    state = np.random.get_state()
    np.random.shuffle(train_images)
    np.random.set_state(state)
    np.random.shuffle(train_labels)

    # Pack the six lists into a 1-D object array. Passing the ragged lists
    # straight to np.asarray raises ValueError on NumPy >= 1.24; element-wise
    # assignment guarantees shape (6,) regardless of the list lengths.
    parts = [
        train_images, train_labels,
        validation_images, validation_labels,
        test_images, test_labels,
    ]
    processed = np.empty(len(parts), dtype=object)
    for index, part in enumerate(parts):
        processed[index] = part
    return processed

def main():
    """Preprocess every image under ``input_data`` and save the splits to ``output_file``."""
    with tf.Session() as session:
        splits = create_image_list(session, test_percentage, validation_percentage)
    # np.save needs no TensorFlow state, so it can run after the session is closed.
    np.save(output_file, splits)

if __name__ == '__main__':
    main()

finetune.py

import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim

import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# Files: preprocessed splits (np.load input), checkpoint prefix passed to
# Saver.save, and the checkpoint restored into the model before training.
input_data = 'flower_processed_data.npy'
train_file = 'model'
ckpt_file = './checkpoints/inception_v3.ckpt'

# Optimiser step size, number of training steps, examples per batch, and
# number of output classes.
learning_rate, steps, batch, n_classes = 1e-4, 300, 32, 5

# Comma-separated variable-scope prefixes. The first list is excluded from
# the checkpoint restore (see get_tuned_variables); the second is collected by
# get_trainable_variables.
checkpoint_exclude_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'
trainable_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits'

def get_tuned_variables():
    """Return the model variables that should be restored from the checkpoint.

    A variable is dropped when its op name starts with any of the
    comma-separated prefixes in ``checkpoint_exclude_scopes``. All other
    variables are kept, in ``slim.get_model_variables()`` order.
    """
    # str.startswith accepts a tuple and is true if any prefix matches.
    prefixes = tuple(part.strip() for part in checkpoint_exclude_scopes.split(','))
    return [
        model_var
        for model_var in slim.get_model_variables()
        if not model_var.op.name.startswith(prefixes)
    ]

def get_trainable_variables():
    """Return the trainable variables under each scope listed in ``trainable_scopes``.

    Scopes are processed in the order they appear in the comma-separated
    string. The variables of each scope are appended in collection order.
    """
    # get_collection's scope argument filters items by name (matched from the
    # start of the name).
    return [
        variable
        for scope_name in (s.strip() for s in trainable_scopes.split(','))
        for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope_name)
    ]

def _batched_accuracy(sess, accuracy, images, labels, data_images, data_labels, batch_size):
    """Evaluate ``accuracy`` over a whole split, batch_size examples per sess.run.

    Feeding an entire split in a single sess.run pushes every image through
    InceptionV3 at once and may exhaust (GPU) memory. Per-batch accuracies are
    weighted by batch size, so the result equals the accuracy over all
    examples. Returns nan for an empty split.
    """
    total = len(data_images)
    if total == 0:
        return float('nan')
    correct = 0.0
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        batch_accuracy = sess.run(accuracy, feed_dict={
            images: data_images[start:end],
            labels: data_labels[start:end],
        })
        correct += batch_accuracy * (end - start)
    return correct / total


def main(argv=None):
    """Fine-tune InceptionV3 on the preprocessed splits stored in ``input_data``.

    Steps:
      * Restore every model variable except those under
        ``checkpoint_exclude_scopes`` from ``ckpt_file``.
      * Train the variables under ``trainable_scopes`` with RMSProp for
        ``steps`` steps of at most ``batch`` examples.
      * Every 30 steps and on the last step, save a checkpoint to
        ``train_file`` and report validation accuracy.
      * Finally, report test accuracy.

    Args:
        argv: unused; accepted because tf.app.run passes the command line.

    Raises:
        ValueError: if the training split is empty.
    """
    # Load the data. The file holds a NumPy object array (one entry per
    # split), which can only be read with pickle enabled; NumPy >= 1.16.3
    # refuses by default. Only load files you produced yourself.
    processed_data = np.load(input_data, allow_pickle=True)
    train_images = processed_data[0]
    n_train_examples = len(train_images)
    train_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    test_images = processed_data[4]
    test_labels = processed_data[5]

    print('train: {},test: {},validation: {}'.format(n_train_examples,len(test_images),len(validation_images)))
    if n_train_examples == 0:
        raise ValueError('no training examples in {}'.format(input_data))

    images = tf.placeholder(tf.float32,[None,299,299,3],name='input-images')
    labels = tf.placeholder(tf.int64,[None],name='labels')

    # InceptionV3 model.
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        # Training tower: dropout active, batch norm in training mode.
        logits, _ = inception_v3.inception_v3(images, n_classes)
        # Evaluation tower that shares the training tower's weights. With
        # is_training=False, dropout is off and batch norm uses its stored
        # moving statistics.
        eval_logits, _ = inception_v3.inception_v3(
            images, n_classes, is_training=False, reuse=True)

    trainable_variables = get_trainable_variables()

    # tf.losses records this loss in its loss collection.
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, n_classes), logits, weights=1.0)

    # Only update the scopes listed in trainable_scopes; without var_list,
    # minimize would update every trainable variable and the selection above
    # would be ignored.
    train_step = tf.train.RMSPropOptimizer(learning_rate).minimize(
        tf.losses.get_total_loss(), var_list=trainable_variables)

    with tf.name_scope('evaluation'):
        correct_prediction = tf.equal(tf.argmax(eval_logits, 1), labels)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Loader for the pretrained (non-excluded) variables.
    load_fn = slim.assign_from_checkpoint_fn(ckpt_file, get_tuned_variables(), ignore_missing_vars=True)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Initialise everything first; load_fn then overwrites the restored
        # variables with their checkpoint values.
        sess.run(tf.global_variables_initializer())

        print('loading tuned variables from : '+ckpt_file)

        load_fn(sess)

        # Walk the training set in consecutive windows of at most `batch`
        # examples, wrapping to the start after the final (possibly short) one.
        start = 0
        end = min(batch, n_train_examples)

        for i in range(steps):
            sess.run(train_step, feed_dict={
                images: train_images[start:end],
                labels: train_labels[start:end],
            })
            if i % 30 == 0 or i + 1 == steps:
                saver.save(sess, train_file, global_step=i)
                acc = _batched_accuracy(sess, accuracy, images, labels,
                                        validation_images, validation_labels, batch)
                print('{} steps, accuracy: {:.4f}'.format(i, acc))
            start = 0 if end == n_train_examples else end
            end = min(start + batch, n_train_examples)

        test_acc = _batched_accuracy(sess, accuracy, images, labels,
                                     test_images, test_labels, batch)
        print('test accuracy: {}'.format(test_acc))

if __name__ == '__main__':
    tf.app.run()



版权声明:本文为a13602955218原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。