Switching the Loss Function by Epoch Value During Keras Training

Switch the loss function according to the epoch value without stopping training.

As an example, we switch from binary cross-entropy loss to focal loss once training passes epoch 50.



1. Code Implementation


Net Module

from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

class Net(object):
    def __init__(self, epochs, img_rows, img_cols):
        # initialize your parameters
        self.epochs = epochs
        self.img_rows = img_rows
        self.img_cols = img_cols

    def network_architecture(self):
        # define your network architecture
        inputs = Input(shape=(self.img_rows, self.img_cols), name='data')
        outputs = Dense(10)(inputs)
        model = Model(inputs=[inputs], outputs=[outputs])
        # initialize current_epoch, a backend variable that the callback
        # below will update at the end of every epoch
        self.current_epoch = K.variable(value=0)
        # compile with the epoch-dependent loss defined further down
        model.compile(optimizer=Adam(learning_rate=1e-4),
                      loss=bce_focal_loss_consequence(change_epoch=50,
                                                      current_epoch=self.current_epoch,
                                                      gamma=2., alpha=.25))
        return model


Callback Module

from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback

class WarmUpCallback(Callback):
    def __init__(self, current_epoch):
        super(WarmUpCallback, self).__init__()
        self.current_epoch = current_epoch

    def on_epoch_end(self, epoch, logs=None):
        # epoch is 0-based, so after epoch i finishes the variable holds i+1
        K.set_value(self.current_epoch, epoch + 1)


Custom Loss Module

from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy

def bce_focal_loss_consequence(change_epoch, current_epoch, gamma, alpha):
    def bce_focal(y_true, y_pred):
        # K.less yields a symbolic tensor, so a plain Python `if` would be
        # resolved once at graph-construction time; K.switch re-evaluates
        # the condition at every training step instead
        bool_case_1 = K.less(current_epoch, change_epoch)
        loss = K.switch(bool_case_1,
                        binary_crossentropy(y_true, y_pred),
                        binary_focal_loss(gamma, alpha)(y_true, y_pred))
        return loss
    return bce_focal
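
Note that binary_focal_loss is not a Keras built-in; the code above assumes it is defined elsewhere. A minimal sketch of one possible implementation, following the standard focal-loss formulation FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t); this exact helper is an assumption, written only to match the call above:

def binary_focal_loss(gamma=2., alpha=.25):
    # hypothetical helper, not part of the original post
    def focal(y_true, y_pred):
        # clip predictions so log() stays finite
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        # p_t: predicted probability of the true class
        p_t = y_true * y_pred + (1. - y_true) * (1. - y_pred)
        alpha_t = y_true * alpha + (1. - y_true) * (1. - alpha)
        # reduce over the last axis so the shape matches binary_crossentropy,
        # since K.switch requires both branches to have the same shape
        return K.mean(-alpha_t * K.pow(1. - p_t, gamma) * K.log(p_t), axis=-1)
    return focal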


Fit Module

# assumes this snippet runs inside a method of Net, so `self` refers
# to the Net instance holding current_epoch
my_callbacks = [WarmUpCallback(current_epoch=self.current_epoch)]
model.fit(training_generator, epochs=self.epochs,
          validation_data=validation_generator, callbacks=my_callbacks)
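
For completeness, a sketch of how the modules could be wired together from outside the class; the generator names and the 28×28 input size are placeholders rather than values from the original post:

# hypothetical wiring of the modules above; training_generator and
# validation_generator stand in for your own data pipeline
net = Net(epochs=100, img_rows=28, img_cols=28)
model = net.network_architecture()
my_callbacks = [WarmUpCallback(current_epoch=net.current_epoch)]
model.fit(training_generator, epochs=net.epochs,
          validation_data=validation_generator, callbacks=my_callbacks)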



2. Loss Curve

[Figure: training loss curve]

As the figure shows, the loss jumps abruptly at epoch 50 exactly as expected, since the loss function is swapped there, and training then continues without interruption.



References:

https://github.com/keras-team/keras/issues/2595

https://stackoverflow.com/questions/42787181/variationnal-auto-encoder-implementing-warm-up-in-keras



Copyright notice: this is an original article by kouwang9779, licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.