python - 使用 TensorBoard 上的 Estimator 测试错误和准确度图
问题描述
我对 TensorFlow 比较陌生,对 TensorBoard 和 Estimator API 来说绝对是新手。我想从我提供的源代码中训练一个稍微修改过的 Tensorflow Resnet 模型,并作为官方模型使用 Estimator。
我需要测试误差和准确度图。精度图只显示一个点,我根本无法得到测试错误。我还需要 x 轴上的时期数,而不是步数。使用低级 Tensorflow 实现这一点更简单,但是我需要使用给定的模型。
我在 resnet_fn 中创建了 learning_rate、cross_entropy 和 train_accuracy 张量,如下所示。我还在 resnet_fn 中添加了 SummarySaverHook。它也无济于事。
def resnet_model_fn(...):
...
tf.identity(learning_rate, name='learning_rate')
tf.summary.scalar('learning_rate', learning_rate)
....
summary_hook = tf.train.SummarySaverHook(
flags.epochs_per_eval,
output_dir=FLAGS.model_dir,
summary_op=tf.summary.merge_all())
return tf.estimator.EstimatorSpec(
mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
eval_metric_ops=metrics,
training_hooks=[summary_hook],
evaluation_hooks=[summary_hook])
这是 resnet_main()。我可以在我的终端上看到那些张量“eval_cross_entropy”等,但是它们根本没有显示在 TensorBoard 中。我也在分享 TensorBoard 的屏幕截图。
def resnet_main(flags, model_function, input_function):
# Using the Winograd non-fused algorithms provides a small performance boost.
os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
# Set up a RunConfig to only save checkpoints once per training cycle.
run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9, save_summary_steps=flags.epochs_per_eval)
classifier = tf.estimator.Estimator(
model_fn=model_function, model_dir=flags.model_dir, config=run_config,
params={
'resnet_size': flags.resnet_size,
'data_format': flags.data_format,
'batch_size': flags.batch_size,
})
for _ in range(flags.train_epochs // flags.epochs_per_eval):
tensors_to_log = {
'learning_rate': 'learning_rate',
'cross_entropy': 'cross_entropy',
'train_accuracy': 'train_accuracy'
}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=flags.epochs_per_eval)
print('Starting a training cycle.')
def input_fn_train():
return input_function(True, flags.data_dir, flags.batch_size,
flags.epochs_per_eval, flags.num_parallel_calls)
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
tensors_to_log_eval = {
'eval_cross_entropy': 'cross_entropy',
'eval_train_accuracy': 'train_accuracy'
}
logging_hook_eval = tf.train.LoggingTensorHook(
tensors=tensors_to_log_eval, every_n_iter=flags.epochs_per_eval)
print('Starting to evaluate.')
# Evaluate the model and print results
def input_fn_eval():
return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls)
eval_results = classifier.evaluate(input_fn=input_fn_eval, hooks=[logging_hook_eval])
tensors_to_log_pred = {
'pred_cross_entropy': 'cross_entropy'
}
logging_hook_pred = tf.train.LoggingTensorHook(
tensors=tensors_to_log_pred, every_n_iter=flags.epochs_per_eval)
print('Starting to predict.')
def input_fn_pred():
return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls)
pred_results = classifier.predict(input_fn=input_fn_pred, hooks=[logging_hook_pred])
return eval_results, pred_results
如何在 x 轴上获得带有时代数的测试错误和准确度图?
解决方案
推荐阅读
- javascript - 在 React 中动态渲染外部 HTML/React 组件
- ios - 以编程方式添加的 UITableView 中的错误单元格插图
- jquery - 未捕获的 TypeError:尝试使用 jQuery 时的非法构造函数
- javascript - 如何以箭头符号解构参数
- webpack - 如何在 vue.js 应用程序的生产环境中禁用源映射?
- php - Laravel 和 Mysql '已经消失了'
- clearcase - 无法识别的命令:“vob_restore”
- .htaccess - .htaccess 重写规则以包含文件
- f# - 如何从 IList 对象中的单个对象创建对象列表
- java - 如何使用 Eclipse 工作区作业等详细信息按钮实现 Eclipse 进度条