tensorflow - 恢复模型 tf.estimator.DNNClassifier
问题描述
model = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[1024, 512, 256])
model.train(input_fn=input_func,steps=5000)
这创建检查点
我复出第 2 天;现在我需要检查点的模型;如何恢复?
sess=tf.Session()
saver = tf.train.import_meta_graph(file_path + "/" + "model.ckpt-1000.meta")
saver.restore(sess,tf.train.latest_checkpoint(file_path))
model = ????? -- how do I get my model back?
解决方案
不知道我为什么挣扎。回答比较简单。阅读关于检查点的文章:https ://www.tensorflow.org/get_started/checkpoints
重新加载模型的非常简单的代码:
model_load = tf.estimator.DNNClassifier(feature_columns=feat_cols, hidden_units=[10, 10, 10, 10], model_dir="C:/Users/AI101~1/AppData/Local/Temp/tmpm2ndcvf_")