训练Inception模型提取特征

  • Post author:
  • Post category:其他




总体思路

(1)读取数据集并做预处理

(2)创建Inception模型

(3)训练模型前86层,不训练最后的全连接层

(4)设定最后的全连接层输出并训练



tips

(1)读取数据集并做预处理

training_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
    )

(2)创建Inception模型

pre_trained_model = InceptionV3(input_shape = (224, 224, 3), 
                  include_top = False, 
                  weights = 'imagenet')

参数

include_top

是布尔型,False表示不包含顶部和最后的全连接层。

weights

有两个值,一个是None,不训练该模型。另一个是imagenet,表示在ImageNet上面预训练。

在这里插入图片描述

(3)训练模型前86层,不训练最后的全连接层

# 训练Inception模型中前86层,不训练最后的全连接层,因为我不需要最后的输出
for i in range(len(pre_trained_model.layers)):
  if i <= 86:
    pre_trained_model.layers[i].trainable = True
  else:
    pre_trained_model.layers[i].trainable = False
  print(i, pre_trained_model.layers[i].name, pre_trained_model.layers[i].trainable)

不做解释比较简单

(4)设定最后的全连接层输出并训练

x = layers.Flatten()(last_output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense  (5, activation='softmax')(x)    

最后的全连接层输出

history = model.fit(train_generator, 
            epochs=5, 
            # steps_per_epoch=1027/5, 
            validation_data = validation_generator, 
            # verbose = 2, 
            # validation_steps=288/5

注释掉的部分有待解释(?)

最后的训练结果:
在这里插入图片描述

两个问题,loss为什么这么大?val_accuracy为什么不变。

于是我决定不训练模型,只训练最后的全连接层。

for layer in pre_trained_model.layers:
  layer.trainable = False
  print(layer.name, layer.trainable)

把所有的层全部冻结。得到的结果是:

在这里插入图片描述

训练和验证集的精准度都很低,初步猜测是没有weights,所以添加一个weights。


local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'
pre_trained_model = InceptionV3(input_shape = (224, 224, 3), 
                  include_top = False, 
                  weights = None)

结果如下:

在这里插入图片描述

效果差不多。于是我把蚊虫识别的模型放进去试试看效果。

在这里插入图片描述

最后得出的是:这种做法是错误的!

后来查了一些论文,可以用下面的方法做:

在这里插入图片描述

思路就是用’imagenet’来预训练模型,然后冻结所有层,在模型后端加上两个Dropout,三个全连接层(relu激活),然后用蚊虫的数据集来训练!

在这里插入图片描述

结果如上图,可以看到准确度可以达到70%,具体代码实现如下:

# 创建Inception预训练模型
from tensorflow.keras.applications.inception_v3 import InceptionV3

pre_trained_model = InceptionV3(input_shape = (224, 224, 3), 
                  include_top = False, 
                  weights = 'imagenet')
print('Model loaded.')

# 冻结所有层
for i in range(len(pre_trained_model.layers)):  
  pre_trained_model.layers[i].trainable = False
  print(i, pre_trained_model.layers[i].name, pre_trained_model.layers[i].trainable)

last_layer = pre_trained_model.output

x = layers.Flatten(x)
x = layers.Dropout(0.5)(x)  
x = layers.Dropout(0.5)(x)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dense(5, activation='softmax')(x) 

# 将这几层连接到pre_trained_model上面
model = Model(pre_trained_model.input, x)

model.compile(
    optimizer=SGD(lr=0.0001, momentum=0.9),               
    loss = 'categorical_crossentropy', 
    metrics = ['accuracy']
    )

history = model.fit(train_generator,
            epochs=10,
            # steps_per_epoch=1027/5,
            validation_data = validation_generator,
            # verbose = 2,
            # validation_steps=288/5
)

虽然训练结果有70%,但是还有其他的改进方法,尝试用SE模块加强特征表征能力。参考文献:Squeeze-and-Excitation Networks。

SE模块的代码如下:

# Attention
def squeeze_excitation_layer(x, out_dim):
    '''
    SE module performs inter-channel weighting.
    '''
    squeeze = GlobalAveragePooling2D()(x) # Flatten

    excitation = Dense(units=out_dim // 4)(squeeze)
    excitation = Activation('relu')(excitation)
    excitation = Dense(units=out_dim)(excitation)
    excitation = Activation('sigmoid')(excitation)
    excitation = Reshape((1, 1, out_dim))(excitation)

    scale = multiply([x, excitation])

    return scale

这个out_dim指的是卷积后的通道数

在这里插入图片描述

本文中是2048.

last_layer = pre_trained_model.output

x = squeeze_excitation_layer(last_layer, 2048)
x = layers.Flatten()(last_layer)

对模型最后的output做SE操作,然后flatten,接上上面的两个Dropout和三层全连接层,输出:

在这里插入图片描述

可以看到准确度还是有 明显提升的,但是准确度的评判并不充足,还是要用混淆矩阵。



版权声明:本文为qq_44705887原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。