Tensorflow2.x 中自定义回调函数
在模型训练的时候 我们经常需要判断当我们模型训练的
acc
或者
loss
到达一个值的时候结束我们模型的训练,因此我们就需要回调函数这个功能。
Tensorflow2.x 回调机制
tf.keras的回调函数实际上是一个类,一般是在model.fit时作为参数指定,用于控制在训练过程开始或者在训练过程结束,在每个epoch训练开始或者训练结束,在每个batch训练开始或者训练结束时执行一些操作,例如收集一些日志信息,改变学习率等超参数,提前终止训练过程等等。
同样地,针对model.evaluate或者model.predict也可以指定callbacks参数,用于控制在评估或预测开始或者结束时,在每个batch开始或者结束时执行一些操作,但这种用法相对少见。
大部分时候,keras.callbacks子模块中定义的回调函数类已经足够使用了,如果有特定的需要,我们也可以通过对keras.callbacks.Callbacks实施子类化构造自定义的回调函数。
所有回调函数都继承至 keras.callbacks.Callbacks基类,拥有params和model这两个属性。
其中params 是一个dict,记录了 training parameters (eg. verbosity, batch size, number of epochs…).
model即当前关联的模型的引用。
此外,对于回调类中的一些方法如on_epoch_begin,on_batch_end,还会有一个输入参数logs, 提供有关当前epoch或者batch的一些信息,并能够记录计算结果,如果model.fit指定了多个回调函数类,这些logs变量将在这些回调函数类的同名函数间依顺序传递。
Tensorflow2.x 内置回调函数
-
BaseLogger: 收集每个epoch上metrics在各个batch上的平均值,对stateful_metrics参数中的带中间状态的指标直接拿最终值无需对各个batch平均,指标均值结果将添加到logs变量中。该回调函数被所有模型默认添加,且是第一个被添加的。
-
History: 将BaseLogger计算的各个epoch的metrics结果记录到history这个dict变量中,并作为model.fit的返回值。该回调函数被所有模型默认添加,在BaseLogger之后被添加。
-
EarlyStopping: 当被监控指标在设定的若干个epoch后没有提升,则提前终止训练。
-
TensorBoard: 为Tensorboard可视化保存日志信息。支持评估指标,计算图,模型参数等的可视化。
-
ModelCheckpoint: 在每个epoch后保存模型。
-
ReduceLROnPlateau:如果监控指标在设定的若干个epoch后没有提升,则以一定的因子减少学习率。
-
TerminateOnNaN:如果遇到loss为NaN,提前终止训练。
-
LearningRateScheduler:学习率控制器。给定学习率lr和epoch的函数关系,根据该函数关系在每个epoch前调整学习率。
-
CSVLogger:将每个epoch后的logs结果记录到CSV文件中。
-
ProgbarLogger:将每个epoch后的logs结果打印到标准输出流中
Tensorflow2.x 自定义回调函数
在这里我们自己定义我们自己的回调函数
import tensorflow as tf
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs={}):
if(logs.get('accuracy')>0.6):
print("\nReached 60% accuracy so cancelling training!")
self.model.stop_training = True
mnist = tf.keras.datasets.fashion_mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
callbacks = myCallback()
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])
测试结果
Epoch 1/10
1864/1875 [============================>.] - ETA: 0s - loss: 0.4724 - accuracy: 0.8312
Reached 60% accuracy so cancelling training!
1875/1875 [==============================] - 2s 1ms/step - loss: 0.4725 - accuracy: 0.8312
<tensorflow.python.keras.callbacks.History at 0x1c75d2b5518>