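1. Fine-tuning, continuing training, and saving a BERT model with bert4keras
Reference: https://github.com/bojone/oppo-text-match/issues/5
## build_transformer_model builds the custom BERT model layers; keras.models.Model creates the overall Keras model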
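from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_weight_decay

# config_path, checkpoint_path, keep_tokens and the TotalLoss layer are
# assumed to be defined beforehand (see the referenced issue).
bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_pool='linear',
    application='unilm',
    keep_tokens=keep_tokens,  # keep only the tokens listed in keep_tokens, shrinking the original vocabulary
    return_keras_model=False,
)
# outputs[0] is the pooled sentence vector, outputs[1] the token predictions
encoder = keras.models.Model(bert.model.inputs, bert.model.outputs[0])
seq2seq = keras.models.Model(bert.model.inputs, bert.model.outputs[1])
outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
model = keras.models.Model(bert.model.inputs, outputs)
AdamW = extend_with_weight_decay(Adam, 'AdamW')
optimizer = AdamW(learning_rate=2e-6, weight_decay_rate=0.01)
model.compile(optimizer=optimizer)
model.summary()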
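The `keep_tokens` argument above is a list of original vocabulary ids whose embedding rows are retained. As a rough illustration of the idea only, the pure-Python sketch below builds such a list from a toy vocabulary. `simplify_vocab`, the toy vocabulary, and the filtering rule are all hypothetical; in practice bert4keras's `load_vocab(dict_path, simplified=True, startswith=[...])` is, as far as I know, the usual way to obtain `token_dict` and `keep_tokens`.

```python
def simplify_vocab(token_dict, startswith, keep):
    """Return (new_token_dict, keep_tokens) for a reduced vocabulary.

    token_dict: full vocabulary, token -> original id
    startswith: special tokens placed first, in this order
    keep:       predicate deciding which of the remaining tokens survive
    """
    new_token_dict, keep_tokens = {}, []
    for token in startswith:
        new_token_dict[token] = len(new_token_dict)
        keep_tokens.append(token_dict[token])
    # Walk the remaining tokens in original-id order so the result is stable.
    for token, old_id in sorted(token_dict.items(), key=lambda kv: kv[1]):
        if token not in new_token_dict and keep(token):
            new_token_dict[token] = len(new_token_dict)
            keep_tokens.append(old_id)
    return new_token_dict, keep_tokens


# Toy vocabulary (hypothetical) and a toy embedding table, one row per original id.
full_vocab = {
    '[PAD]': 0, '[unused1]': 1, '[UNK]': 2, '[CLS]': 3,
    '[SEP]': 4, '你': 5, '##ing': 6, '好': 7,
}
embeddings = [[float(i)] * 2 for i in range(len(full_vocab))]

new_vocab, keep_tokens = simplify_vocab(
    full_vocab,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
    keep=lambda t: not t.startswith('[unused') and not t.startswith('##'),
)

# Keeping only the selected rows is what shrinks the embedding matrix.
small_embeddings = [embeddings[i] for i in keep_tokens]
print(new_vocab)
print(keep_tokens)
```

The new ids index into the reduced table, while `keep_tokens` records which rows of the pretrained embedding matrix they came from.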
## Saving the trained BERT model
bert = build_transformer_model(xxx, return_keras_model=False)
model1 = bert.model
model1.load_weights(xxxxxx)
bert.save_weights_as_checkpoint(xxxxx)
## Saving directly after training
model.fit_generator(
    train_generator.forfit(),
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    # callbacks=[evaluator]
)
bert.save_weights_as_checkpoint("***.ckpt")
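A TensorFlow checkpoint name such as the one passed above is a prefix, not a single file. The sketch below checks which files exist for a given prefix. It assumes the standard TensorFlow layout (`.index`, `.data-*`, and optionally `.meta`); whether a `.meta` file is written depends on how the checkpoint was saved, so treat that part as an assumption. `checkpoint_files` and the `demo.ckpt` name are hypothetical.

```python
import glob
import os
import tempfile


def checkpoint_files(prefix):
    """Report which files exist for a TensorFlow checkpoint prefix."""
    return {
        'index': os.path.exists(prefix + '.index'),
        'meta': os.path.exists(prefix + '.meta'),
        # glob.escape protects special characters that may appear in the prefix.
        'data': sorted(glob.glob(glob.escape(prefix) + '.data-*')),
    }


# Demonstration on empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, 'demo.ckpt')
    for suffix in ('.index', '.meta', '.data-00000-of-00001'):
        open(prefix + suffix, 'w').close()
    found = checkpoint_files(prefix)
    data_names = [os.path.basename(p) for p in found['data']]

print(found['index'], found['meta'], data_names)
```

If the `.meta` file is present, it is the file that the TensorBoard section of these notes imports to display the graph.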
***For pretraining completely from scratch, see: https://github.com/bojone/bert4keras/tree/master/pretraining
2. Reading a ckpt model's graph structure and displaying it in TensorBoard
Reference: https://blog.csdn.net/Mmagic1/article/details/106071818
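# Uses the TensorFlow 1.x API (tf.Session, tf.summary.FileWriter)
import tensorflow as tf
from tensorflow.summary import FileWriter
sess = tf.Session()
# Load the graph definition stored next to the checkpoint into the default graph
tf.train.import_meta_graph("./model.ckpt.meta")
# Write the graph as an event file into the __tb folder
FileWriter("__tb", sess.graph)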
Output:
In the __tb folder, hold the Shift key and right-click to open PowerShell, then enter in PowerShell: tensorboard --logdir=./ --host=127.0.0.15
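The same command can also be started from Python instead of PowerShell. This is only a minimal sketch: `tensorboard_command` and `launch_tensorboard` are hypothetical helper names, and the port value is an assumption (6006 is TensorBoard's usual default); the host is the one used in the command above.

```python
import subprocess


def tensorboard_command(logdir, host, port=6006):
    """Build the argument list for the tensorboard CLI."""
    return ['tensorboard', '--logdir', logdir, '--host', host, '--port', str(port)]


def launch_tensorboard(logdir, host, port=6006):
    """Run TensorBoard in the foreground; blocks until it is stopped."""
    subprocess.run(tensorboard_command(logdir, host, port), check=True)


# Only build and show the command here; call launch_tensorboard(...) to start it.
cmd = tensorboard_command('__tb', host='127.0.0.15')
print(' '.join(cmd))
```

Calling `launch_tensorboard('__tb', host='127.0.0.15')` from the directory that contains `__tb` should then serve the graph at the chosen host and port.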
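Copyright notice: This is an original article by weixin_42357472, licensed under CC 4.0 BY-SA. Please include a link to the original source and this notice when reposting.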