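Switching the loss function by epoch without stopping training
As an example, we switch from binary cross-entropy loss to focal loss once training has passed 50 epochs.
1. Code Implementation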
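For reference, these are the two per-sample losses being switched, for a label y ∈ {0, 1} and a predicted probability p. The focal loss shown is the standard binary form with α weighting the positive class. This is an assumption, because the `binary_focal_loss` used below is not shown. Here e is the 0-based epoch index, and γ = 2, α = 0.25 match the arguments passed in the code.

```latex
\mathrm{BCE}(y, p) = -\big[\, y \log p + (1 - y)\log(1 - p) \,\big]

\mathrm{FL}(y, p) = -\big[\, \alpha\, y\, (1 - p)^{\gamma} \log p + (1 - \alpha)(1 - y)\, p^{\gamma} \log(1 - p) \,\big]

\mathcal{L}_e = \begin{cases} \mathrm{BCE}(y, p), & e < 50 \\ \mathrm{FL}(y, p), & e \ge 50 \end{cases}
```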
Net Module
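from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

class Net:
    def __init__(self, epochs, img_rows, img_cols):
        # initialize your parameters
        self.epochs = epochs
        self.img_rows = img_rows
        self.img_cols = img_cols

    def network_architecture(self):
        # define your network architecture (a single Dense layer as a placeholder);
        # sigmoid outputs are probabilities, which binary_crossentropy expects by default
        inputs = Input(shape=(self.img_rows, self.img_cols), name='data')
        outputs = Dense(10, activation='sigmoid')(inputs)
        model = Model(inputs=[inputs], outputs=[outputs])
        # epoch counter kept as a backend variable, so the compiled loss reads its
        # current value at every step instead of a constant frozen at compile time
        self.current_epoch = K.variable(value=0)
        # compile the model with the epoch-dependent loss
        model.compile(optimizer=Adam(learning_rate=1e-4),
                      loss=bce_focal_loss_consequence(change_epoch=50,
                                                      current_epoch=self.current_epoch,
                                                      gamma=2., alpha=.25))
        return model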
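Why the epoch counter has to be a variable: `model.compile` builds the loss closure once, so a plain Python int captured at that moment would never change. The sketch below is a pure-Python analogy, not Keras code. `EpochVariable` is a hypothetical stand-in for `K.variable`, and the selectors only return a loss name so the closure behaviour is easy to see.

```python
class EpochVariable:
    """Hypothetical stand-in for K.variable: a mutable holder read at call time."""

    def __init__(self, value=0):
        self.value = value


def make_loss_selector(current_epoch, change_epoch):
    """Closure that reads the holder on every call, like bce_focal reads K.variable."""
    def select(_y_true=None, _y_pred=None):
        return "bce" if current_epoch.value < change_epoch else "focal"
    return select


def make_frozen_selector(current_epoch_int, change_epoch):
    """Same closure, but capturing a plain int: its value is fixed when it is built."""
    def select(_y_true=None, _y_pred=None):
        return "bce" if current_epoch_int < change_epoch else "focal"
    return select


epoch = EpochVariable(0)
live = make_loss_selector(epoch, change_epoch=50)
frozen = make_frozen_selector(epoch.value, change_epoch=50)

epoch.value = 50          # what the callback does after epoch index 49 ends
print(live(), frozen())   # focal bce
```

Only the closure that reads the shared holder notices the update, which is the role `K.variable` plus `K.set_value` plays in the Keras code.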
Callback Module
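from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback

class WarmUpCallback(Callback):
    def __init__(self, current_epoch):
        super().__init__()
        self.current_epoch = current_epoch  # backend variable read by the loss

    def on_epoch_end(self, epoch, logs=None):
        K.set_value(self.current_epoch, epoch + 1)  # epoch is 0-based; epoch + 1 runs next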
Custom Loss Module
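from tensorflow.keras import backend as K
from tensorflow.keras.losses import binary_crossentropy

# binary_focal_loss(gamma, alpha) is your own factory that returns a loss
# function; it is not part of Keras
def bce_focal_loss_consequence(change_epoch, current_epoch, gamma, alpha):
    focal = binary_focal_loss(gamma, alpha)

    def bce_focal(y_true, y_pred):
        # current_epoch is a tensor, so the branch is selected with K.switch;
        # a Python `if` on a tensor fails in graph mode unless AutoGraph converts it.
        # Both branches should return per-sample losses of the same shape.
        use_bce = K.less(current_epoch, change_epoch)
        return K.switch(use_bce,
                        lambda: binary_crossentropy(y_true, y_pred),
                        lambda: focal(y_true, y_pred))

    return bce_focal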
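The same schedule can be checked numerically outside Keras. This is a pure-Python sketch with hypothetical names (`bce`, `binary_focal`, `bce_focal_schedule`); `binary_focal` assumes the standard binary focal loss, since the original `binary_focal_loss` implementation is not shown. Here the epoch is a plain int, so an ordinary `if` is fine; in the Keras version the epoch is a tensor and the branch has to be a graph op such as `K.switch`.

```python
import math

EPS = 1e-7  # clip probabilities away from 0 and 1 before taking logs


def bce(y, p):
    """Binary cross-entropy for one label y in {0, 1} and probability p."""
    p = min(max(p, EPS), 1 - EPS)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def binary_focal(y, p, gamma=2.0, alpha=0.25):
    """Standard binary focal loss (assumed form); alpha weights the positive class."""
    p = min(max(p, EPS), 1 - EPS)
    return -(alpha * y * (1 - p) ** gamma * math.log(p)
             + (1 - alpha) * (1 - y) * p ** gamma * math.log(1 - p))


def bce_focal_schedule(current_epoch, change_epoch=50, gamma=2.0, alpha=0.25):
    """Pick the per-sample loss from the epoch index, mirroring bce_focal."""
    if current_epoch < change_epoch:
        return bce
    return lambda y, p: binary_focal(y, p, gamma, alpha)


print(round(bce_focal_schedule(10)(1, 0.9), 5))   # 0.10536
print(round(bce_focal_schedule(60)(1, 0.9), 5))   # 0.00026
```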
Fit Module
# inside a Net method, after model = self.network_architecture()
my_callbacks = [WarmUpCallback(current_epoch=self.current_epoch)]
model.fit(training_generator, epochs=self.epochs,
          validation_data=validation_generator, callbacks=my_callbacks)
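To see exactly which epochs use which loss, the callback and the fit loop can be simulated in plain Python. `EpochVariable`, `WarmUpCallbackSketch` and `simulate_fit` are hypothetical stand-ins for `K.variable`, `WarmUpCallback` and `model.fit`; the loop assumes, as in Keras, that callbacks receive a 0-based epoch index.

```python
class EpochVariable:
    """Hypothetical stand-in for K.variable / K.set_value."""

    def __init__(self, value=0):
        self.value = value


class WarmUpCallbackSketch:
    """Mirrors WarmUpCallback.on_epoch_end: store epoch + 1 after each epoch."""

    def __init__(self, current_epoch):
        self.current_epoch = current_epoch

    def on_epoch_end(self, epoch, logs=None):
        self.current_epoch.value = epoch + 1


def simulate_fit(epochs, change_epoch=50):
    """Record which loss each epoch (0-based index) would train with."""
    current_epoch = EpochVariable(0)
    callbacks = [WarmUpCallbackSketch(current_epoch)]
    used = []
    for epoch in range(epochs):
        # every training step in this epoch reads the same counter value
        used.append("bce" if current_epoch.value < change_epoch else "focal")
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    return used


schedule = simulate_fit(epochs=60)
print(schedule.count("bce"), schedule.count("focal"))  # 50 10
print(schedule.index("focal"))                         # 50
```

With `change_epoch=50`, epochs 0 to 49 (the first 50) train with BCE and focal loss starts with the 51st epoch. If training is resumed with `model.fit(..., initial_epoch=k)`, Keras numbers the epochs from `k`, but `current_epoch` still holds 0 until the first resumed epoch ends; initializing it with `K.variable(value=k)` would remove that one-epoch lag.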
2. Loss Curve
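As the figure above shows, the loss changes abruptly at epoch 50, as expected, because the loss function is swapped at that point; training then continues without interruption.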
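Why the curve jumps instead of changing smoothly: for the same prediction, focal loss is cross-entropy multiplied by a weight that is at most 1. Assuming `binary_focal_loss` is the standard binary focal loss, write p_t for the probability assigned to the true class and α_t for the class weight:

```latex
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}

\frac{\mathrm{FL}(y, p)}{\mathrm{BCE}(y, p)}
  = \frac{-\alpha_t (1 - p_t)^{\gamma} \log p_t}{-\log p_t}
  = \alpha_t (1 - p_t)^{\gamma}

\gamma = 2,\ \alpha = 0.25,\ y = 1,\ p = 0.9:\quad 0.25 \cdot 0.1^{2} = 0.0025
```

So at the switch the reported loss value is rescaled even though the model's predictions have not changed yet, most strongly for well-classified samples.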
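Copyright notice: this is an original article by kouwang9779, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.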