Automatic hyperparameter tuning in TensorFlow with Keras Tuner
Code repository:
https://github.com/lilihongjava/deep_learning/tree/master/TensorFlow2.0%E8%87%AA%E5%8A%A8%E8%B0%83%E5%8F%82
Dataset
The Fashion-MNIST dataset of Zalando article images. The load_data function reads the files 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', and 't10k-images-idx3-ubyte.gz' from the data directory:
import gzip

import numpy as np


def load_data():
    path = "./data/"
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]
    paths = [path + each for each in files]
    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)  # uint8: unsigned 8-bit integers (0-255), one byte each; pixels use 256 gray levels
    with gzip.open(paths[1], 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)  # image size is 28x28
    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)  # offset=8 skips the 8-byte file header
    with gzip.open(paths[3], 'rb') as imgpath:
        x_test = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (x_train, y_train), (x_test, y_test)
(img_train, label_train), (img_test, label_test) = load_data()
Normalize pixel values to the range [0, 1]:
img_train = img_train.astype('float32') / 255.0
img_test = img_test.astype('float32') / 255.0
Image classification model
hypermodel
Tune the number of units in the first Dense layer, choosing an optimal value between 32 and 512:
hp.Int('units', min_value=32, max_value=512, step=32)
Tune the learning rate of the optimizer, choosing an optimal value from 0.01, 0.001, or 0.0001:
hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])
from tensorflow import keras


def model_builder(hp):
    model = keras.Sequential()
    model.add(keras.layers.Flatten(input_shape=(28, 28)))  # flatten the multi-dimensional input to one dimension
    # Tune the number of units in the first Dense layer
    # Choose an optimal value between 32-512
    hp_units = hp.Int('units', min_value=32, max_value=512, step=32)
    model.add(keras.layers.Dense(units=hp_units, activation='relu'))
    model.add(keras.layers.Dense(10))
    # Tune the learning rate for the optimizer
    # Choose an optimal value from 0.01, 0.001, or 0.0001
    hp_learning_rate = hp.Choice('learning_rate', values=[1e-2, 1e-3, 1e-4])
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=hp_learning_rate),
                  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])  # accuracy is the metric used to judge model performance
    return model
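To sanity-check the hypermodel outside of a tuning run, you can build it once with a default HyperParameters object; this is a minimal sketch (assuming Keras Tuner is imported as kt), where hp.Int and hp.Choice simply fall back to their default values:
import kerastuner as kt

# Build the hypermodel once with default hyperparameter values
# to verify that it constructs and compiles without errors
hp = kt.HyperParameters()
model = model_builder(hp)
model.summary()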
Hyperband
Search for the hyperparameters with the Hyperband algorithm.
Define the Hyperband tuner, specifying the hypermodel, the objective to optimize, the maximum number of epochs, the reduction factor, and where to store detailed logs and checkpoints:
tuner = kt.Hyperband(model_builder,
                     objective='val_accuracy',  # objective to optimize: validation accuracy
                     max_epochs=10,  # maximum number of epochs to train one model
                     factor=3,  # reduction factor for the number of epochs and models per bracket
                     directory='my_dir',  # my_dir/intro_to_kt holds detailed logs and checkpoints from the search
                     project_name='intro_to_kt')
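Before launching the search, Keras Tuner can print the registered search space for inspection:
# Show each hyperparameter registered by model_builder and its search range
tuner.search_space_summary()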
Run the hyperparameter search (automatic tuning).
ClearTrainingOutput is a callback that clears the output at the end of every training run:
tuner.search(img_train, label_train, epochs=10, validation_data=(img_test, label_test),
             callbacks=[ClearTrainingOutput()])
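ClearTrainingOutput is not defined in the snippets above; the referenced TensorFlow tutorial implements it as a small Keras callback that clears the cell output when a training run ends. A minimal sketch, assuming a Jupyter/IPython environment:
import IPython
import tensorflow as tf


class ClearTrainingOutput(tf.keras.callbacks.Callback):
    def on_train_end(self, *args, **kwargs):
        # Clear the notebook cell output once a training run finishes
        IPython.display.clear_output(wait=True)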
Get the best hyperparameters:
best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
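As an alternative to rebuilding from best_hps and retraining, Keras Tuner can also return the best model from the search directly, with its trained weights:
# Alternative: retrieve the best model found during the search,
# already trained with the checkpointed weights
best_model = tuner.get_best_models(num_models=1)[0]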
Build and train the model with the best hyperparameters:
model = tuner.hypermodel.build(best_hps)
model.fit(img_train, label_train, epochs=10, validation_data=(img_test, label_test))
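After retraining, it is worth evaluating the tuned model on the held-out test set (a short sketch using the arrays loaded above):
# Evaluate the tuned model on the held-out test images
eval_loss, eval_acc = model.evaluate(img_test, label_test)
print(f"test loss: {eval_loss:.4f}, test accuracy: {eval_acc:.4f}")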
Full code
import kerastuner as kt

# load_data, model_builder, and ClearTrainingOutput are defined above
if __name__ == '__main__':
    # Zalando article images dataset (Fashion-MNIST)
    (img_train, label_train), (img_test, label_test) = load_data()
    # Normalize pixel values to [0, 1]
    img_train = img_train.astype('float32') / 255.0
    img_test = img_test.astype('float32') / 255.0
    # Search for hyperparameters with the Hyperband algorithm
    tuner = kt.Hyperband(model_builder,
                         objective='val_accuracy',  # objective to optimize: validation accuracy
                         max_epochs=10,  # maximum number of epochs to train one model
                         factor=3,  # reduction factor
                         directory='my_dir',  # my_dir/intro_to_kt holds detailed logs and checkpoints from the search
                         project_name='intro_to_kt')
    tuner.search(img_train, label_train, epochs=10, validation_data=(img_test, label_test),
                 callbacks=[ClearTrainingOutput()])
    # Get the optimal hyperparameters
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    print(f"""
The hyperparameter search is complete. The optimal number of units in the first densely-connected
layer is {best_hps.get('units')} and the optimal learning rate for the optimizer
is {best_hps.get('learning_rate')}.
""")
    # Build the model with the optimal hyperparameters and train it on the data
    model = tuner.hypermodel.build(best_hps)
    model.fit(img_train, label_train, epochs=10, validation_data=(img_test, label_test))
Reference: https://www.tensorflow.org/tutorials/keras/keras_tuner