首页 > 解决方案 > 恢复模型 tf.estimator.DNNClassifier

问题描述

model = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[1024, 512, 256])
model.train(input_fn=input_func,steps=5000)

这创建检查点

我复出第 2 天;现在我需要检查点的模型;如何恢复?

sess=tf.Session()
saver = tf.train.import_meta_graph(file_path + "/" + "model.ckpt-1000.meta")
saver.restore(sess,tf.train.latest_checkpoint(file_path))
model = ????? -- how do I get my model back?

标签: tensorflow

解决方案


不知道我为什么挣扎。回答比较简单。阅读关于检查点的文章:https ://www.tensorflow.org/get_started/checkpoints

重新加载模型的非常简单的代码:

model_load = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[10, 10, 10, 10], model_dir="C:/Users/AI101~1/AppData/Local/Temp/tmpm2ndcvf_")

推荐阅读