python - 通过 HParams 和 Tensorboard 进行贝叶斯优化
问题描述
我目前正在使用 HParams 来发起一个网格搜索超参数优化会话,该会话运行良好,并且正在将日志输出到我的 tensorboard HParams 插件,我可以看到各种不同的运行和并行坐标视图。代码的结构是这样的,尽管可能没有必要针对这个问题对其进行审查:
def hparam_wrap(args, n_classes, train_dataset, val_dataset, tokenizer):
log_date_subfolder = time.strftime("%Y%m%d-%H%M%S")
hparams_dict={
'HP_EMBEDDING_NODES': hp.HParam('embedding_nodes', hp.Discrete([200,300])),
'HP_LSTM_NODES': hp.HParam('lstm_nodes', hp.Discrete([200,300])),
'HP_TIMEDIST_NODES': hp.HParam('timedist_nodes', hp.Discrete([200,300])),
'HP_NUM_DENSE_LAYERS': hp.HParam('num_dense_layers', hp.Discrete([3,4, 5])),
'HP_DENSE_NODES': hp.HParam('dense_nodes', hp.Discrete([300,400, 500])),
'HP_LEARNING_RATE': hp.HParam('learning_rate', hp.Discrete([0.001, 0.0001, 0.00001])),
'HP_DROPOUT': hp.HParam('dropout', hp.Discrete([0.3, 0.4,0.5, 0.6])),
'HP_BATCH_SIZE': hp.HParam('batch_size', hp.Discrete([96]))
}
session_num = 0
for en in hparams_dict['HP_EMBEDDING_NODES'].domain.values:
for ln in hparams_dict['HP_LSTM_NODES'].domain.values:
for td in hparams_dict['HP_TIMEDIST_NODES'].domain.values:
for dl in hparams_dict['HP_NUM_DENSE_LAYERS'].domain.values:
for dn in hparams_dict['HP_DENSE_NODES'].domain.values:
for lr in hparams_dict['HP_LEARNING_RATE'].domain.values:
for do in hparams_dict['HP_DROPOUT'].domain.values:
for bs in hparams_dict['HP_BATCH_SIZE'].domain.values:
hparams ={
'HP_NUM_DENSE_LAYERS': dl,
'HP_LEARNING_RATE': lr,
'HP_DROPOUT': do,
'HP_DENSE_NODES': dn,
'HP_BATCH_SIZE': bs,
'HP_EMBEDDING_NODES': en,
'HP_LSTM_NODES': ln,
'HP_TIMEDIST_NODES': td
}
run_name = "run-%d" % session_num
print('--- Starting trial: %s' % run_name)
print({h: hparams[h] for h in hparams})
log_dir = os.path.join('s3://sn-classification', args.type, 'Logs', args.country,
args.subfolder, 'HParams', log_date_subfolder)
run_hparam(log_dir, hparams, hparams_dict, args, n_classes, train_dataset,
val_dataset, tokenizer)
session_num += 1
def run_hparam(log_dir, hparams, hparams_dict, args, n_classes, train_dataset, val_dataset, tokenizer):
with tf.summary.create_file_writer(log_dir).as_default():
hp.hparams_config(
hparams=list(hparams_dict.values()),
metrics=[hp.Metric('val_top_k_categorical_accuracy', display_name='TopK_Val_Accuracy'),hp.Metric('val_loss', display_name='val_loss')]
)
# hp.hparams(hparams) # record the values used in this trial
hp.hparams({hparams_dict[h]: hparams[h] for h in hparams_dict.keys()})
history = train(args,n_classes,hparams,train_dataset, val_dataset, tokenizer)
tf.summary.scalar('val_top_k_categorical_accuracy', history['val_top_k_categorical_accuracy'][-1], step=1)
tf.summary.scalar('val_loss', history['val_loss'][-1], step=1)
我已经做了很多谷歌搜索,但我仍然不确定如何实施更有效的优化会话,例如贝叶斯优化,以便以更快的方式找到最佳模型。我想知道的是 - 是否可以在 HParams 中进行贝叶斯优化,或者我是否需要使用不同的包,如权重和偏差?如果可能的话,任何关于在哪里可以找到这种实现示例的建议都会非常有帮助。
解决方案
这是一个长期开放的功能请求,遗憾的是目前仍未在该HPARAMS
部分实现,但Keras-tuner
允许您记录每次运行的结果。将超参数值编码到这些目录名称中可能是一种快速而肮脏的解决方法。为了未来读者的利益,我在此答案的末尾提供了使用 TensorBoard 进行贝叶斯优化的指南。
我可能会补充一点,TensorBoard 可视化对于使用网格或随机搜索来告知开发人员的手动调整直觉很有用,但由于贝叶斯优化是一个独立的黑盒优化器,你应该能够让它运行而不影响优化本身由于缺乏可视化——尽管我同意这仍然是一个不错的功能。
为了在 TensorFlow 中实现贝叶斯优化并记录每次运行的损失,我为未来的读者提供以下内容:
首先定义一个 HyperParameters 对象hp
。
from kerastuner.engine.hyperparameters import HyperParameters
hp = HyperParameters()
编写一个model_builder
带参数的函数hp
,使用 将超参数合并到模型中hp.get('name')
。定义一个 Keras-tunerBayesianOptimization
调谐器。
import kerastuner as kt
tuner = kt.BayesianOptimization(model_builder,
hyperparameters = hp,
max_trials = 20,
objective = 'val_loss')
在您的回调中包含tf.keras.callbacks.TensorBoard(cb_dir)
在目录中记录 BaysianOptimiser 每次运行的损失图cb_dir
。这包括针对时代的标量图,但不包括该HPARAMS
部分。您可能希望命名这些运行文件,以便它们列出超参数。
tuner.search(inputs, prices,
validation_split = 0.2,
batch_size = 32,
callbacks = [tf.keras.callbacks.TensorBoard(cb_dir)],
epochs = 30)
n
通过以下方式访问得分最高的超参数组合的字典
ith_best_hp_dict = tuner.get_best_hyperparameters(num_trials = n)[i]
推荐阅读
- r - 回归结果在 r 中产生 NA 值
- c# - 类型系统怪异:Enumerable.Cast
() - terraform - Terraform如何遍历cloudinit用户数据块
- javascript - 使用 express 提供静态文件以进行反应
- html - 使用 React.js 将具有多列的行添加到 html 表
- excel - 如果满足条件,VBA 从工作表复制行并将其粘贴到不同的工作表中
- mysql - 两个 MySQL 表之间的关系问题
- dart - 无论如何要在 Flutter Webview 上绘制一个小部件
- dart - 如何在flutter redux中不构建小部件的情况下调度动作?
- r - 如何聚合不同日期的数据并考虑 R 中的其他列?