python - keras.estimator.model_to_estimator - 无法热启动或加载先前的检查点
问题描述
我正在使用 keras model_to_estimator 函数训练张量流模型,然后使用训练数据进行训练。这很好用,然后我可以继续使用测试数据成功预测。
在一个单独的例程中,我希望能够使用最新的训练检查点加载预训练的估计器并进行更多的预测(即无需重新训练)。我已经看过了,warm_start_from
但这在加载 keras 模型时似乎不可用。我对https://www.tensorflow.org/get_started/checkpoints的理解是,我可以从同一个 keras 模型创建一个新的估计器,并且第一次预测它会从我指定的目录中加载检查点。
以下代码片段是我尝试执行此操作(最终 estimator_model2 将在单独的例程中加载,这只是为了演示)。
modelConfig = tf.estimator.RunConfig('/myCheckpointpath', keep_checkpoint_max=1)
estimator_model = keras.estimator.model_to_estimator(keras_model=myKerasModel(inputShape, nOutputs), config=modelConfig)
estimator_model.train(input_fn=lambda: input_fn(_trainData_2d, _trainLabels, batch_size=self.batchSize, shuffle=True, num_epochs=2))
estimator_model2 = keras.estimator.model_to_estimator(keras_model=myKerasModel(inputShape, nOutputs), config=modelConfig)
predictions = list(estimator_model2.predict(input_fn=lambda: input_fn(_testData_2d)))
从诊断中我可以看到它在执行最后一行时尝试加载检查点。但是,我收到一个错误,表明训练期间保存的检查点不包含新估计器所需的所有信息。这是错误:
E NotFoundError (see above for traceback): Key conv2d_2/bias not found in checkpoint
E [[Node: save/RestoreV2 = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT64, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
如果有帮助,我可以展示 keras 模型,但我认为这不是问题。
谁能给我一个解决方案或建议一种更好的方法来加载具有先前训练过的值的估计器来进行预测?
解决方案
我对上述问题的解决方案是使用混合方法,其中我使用 keras 表示法指定模型,然后将其放入 tensorflow 模型函数中,然后可以将其加载到估计器中。通过采用这种方法,我可以像使用任何其他 tensorflow 模型一样保存到检查点并从中重新加载。我认为这提供了使用直观 keras 表示法的最佳组合,同时能够利用 tensorflow 估计器和数据工具。以下是我的方法的概述,描述了各种 tensorflow 调用的设置:
创建一个估算器:
|--estimator: tf.estimator.Estimator |--config: tf.estimator.RunConfig #checkpointPath and saving spec for training |--model_fn: tf.estimator.EstimatorSpec |--myKerasModel #specify model. Doesn't have to be keras. |--keras.models.Model |--loss: myLossFunction #train_and_eval only |--optimizer: myOptimizerFunction #train_and_eval only |--training_hooks:tf.train.SummarySaverHook #train_and_eval only - for saved diagnostics |--evaluation_hooks:tf.train.SummarySaverHook #train_and_eval only - for saved diagnostics |--predictions: model(data, training=False) #predict only
训练和评估:
|--tf.estimator.train_and_evaluate |--estimator: tf.estimator.Estimator #the estimator we created |--train_spec: tf.estimator.TrainSpec |--input_fn: tf.data.Dataset #specify the input function |--features, labels #data to use |--batch, shuffle, num_epochs #controls for data |--eval_spec: tf.estimator.EvalSpec |--input_fn: tf.data.Dataset #specify the input function |--features, labels #data to use |--batch #controls for data |--throttle and start_delay #specify when to start evaluation
请注意,我们可以使用tf.estimator.Estimator.train
andtf.estimator.Estimator.evaluate
但不允许在训练期间进行评估,因此我们使用tf.estimator.train_and_evaluate
。
预测:
|--estimator: tf.estimator.Estimator.predict #the estimator we created |--input_fn: tf.data.Dataset #specify the input function |--features #data to use
推荐阅读
- android - 吸气剂可见性未定义
- design-patterns - 了解如何在分布式系统中立即返回错误以进行支付
- python - 是否有一些反向直接方法可以将字典转换为数据框?
- java - jvm GC后操作系统内存会发生变化吗?
- python - Flask - 从 HTML 页面获取输入并将输入传递给另一个 Python 文件中的函数
- android - 推送通知操作按钮不显示 - 背景和终止状态 - React Native
- csv - 如何使用 javascript apify puppeteer 从谷歌趋势下载 Csv
- javascript - 在资源数组中找到正确的动态路径资源的最佳方法是什么?
- haskell - 如何将 Pair 定义为 Monoid?
- c# - Unity 播放器控制器脚本不起作用