首页 > 解决方案 > Tensorflow:只有在训练期间将错误最小化,我才能保存检查点?

问题描述

我正在运行一个 tensorflow 程序,我想存储最好的模型以备后用。我正在使用估计器tf.contrib.tpu.TPUEstimator接受 run_config 参数的模块,我在其中设置save_checkpoints_secs=20*60)进行训练。

estimator.train 将 train_input_fn 和 num_train_steps 作为参数。例如: estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

我不想在每“n”秒后保存检查点,而是想存储训练时误差最小的最佳模型。

欢迎任何帮助。

标签: tensorflow

解决方案


tf.estimator.BestExporter 似乎正是您正在寻找的。根据文档,它指出:

每当新模型优于任何现有模型时,此类都会执行模型导出。

  estimator = tf.estimator.DNNClassifier(
      config=tf.estimator.RunConfig(
          model_dir='/my_model', save_summary_steps=100),
      feature_columns=[categorial_feature_a_emb, ...],
      hidden_units=[1024, 512, 256])

  serving_feature_spec = tf.feature_column.make_parse_example_spec(
      categorial_feature_a_emb)
  serving_input_receiver_fn = (
      tf.estimator.export.build_parsing_serving_input_receiver_fn(
      serving_feature_spec))

  exporter = tf.estimator.BestExporter(
      name="best_exporter",
      serving_input_receiver_fn=serving_input_receiver_fn,
      exports_to_keep=5)

  train_spec = tf.estimator.TrainSpec(...)

  eval_spec = [tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=100,
    exporters=exporter,
    start_delay_secs=0,
    throttle_secs=5)]

推荐阅读