tensorflow - 从保存的 TensorFlow RandomForest 分类器中加载操作
问题描述
我已经训练了一个类似于以下代码的 TF 随机森林分类器:
X = tf.placeholder(tf.float32, shape=[None, num_features])
Y = tf.placeholder(tf.int32, shape=[None])
hparams = tensor_forest.ForestHParams(num_classes=num_classes,
num_features=num_features,
num_trees=num_trees).fill()
forest_graph = tensor_forest.RandomForestGraphs(hparams)
train_op = forest_graph.training_graph(X, Y)
loss_op = forest_graph.training_loss(X, Y)
infer_op, _, _ = forest_graph.inference_graph(X)
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y,tf.int64))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init_vars = tf.group(tf.global_variables_initializer(),
resources.initialize_resources(resources.shared_resources()))
with tf.Session() as sess:
sess.run(init_vars)
saver = tf.train.Saver()
for i in range(1, 100):
for batch_x, batch_y in render_batch(batch_size):
_, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
if acc >= 0.87:
print("Stopping and saving")
save_path = saver.save(sess, model_path)
print("Model saved in file: %s" % save_path)
break
现在我想重新加载我的模型并使用它来对看不见的数据进行预测,如下所示:
with graph.as_default():
session_conf = tf.ConfigProto()
sess = tf.Session(config = session_conf)
with sess.as_default():
saver = tf.train.import_meta_graph("{}.meta".format(model_path))
saver.restore(sess,checkpoint_file)
accuracy_op = graph.get_operation_by_name("accuracy_op").outputs[0]
print(sess.run(accuracy_op, feed_dict={X: x_test, Y: y_test}))
但是,我收到以下错误消息:
KeyError: "The name 'accuracy_op' refers to an Operation not in the graph."
我的问题是 - 我如何保存我的模型,以便在重新加载它时,我可以导入上面定义的那些操作并将它们用于看不见的数据?
谢谢!
解决方案
由于您使用的是get_operation_by_name
,因此您应该将 op 命名为accuracy_op
。你可以通过使用来做到这一点tf.identity
:
accuracy_op = tf.identity(tf.reduce_mean(tf.cast(correct_prediction, tf.float32)), 'accuracy_op')
我看到您正在使用张量X
并且Y
没有从新图表中加载。所以在原始代码中命名张量,然后使用get_tensor_by_name()
推荐阅读
- python - 使用 urllib 更改 URL 的主机名
- javascript - 将蛇形大小写字符串转换为标题大小写
- python - 我在哪里放置来自 github 的 python 存储库文件?
- c - 在 C 中动态包含源文件
- conditional-statements - 为什么即使将条件显式设置为 true,我的 doIf 代码块也不执行?
- python - 尝试写入另一个程序使用的文件时如何避免 OSErrors?
- c# - 在字符串中搜索超链接
- batch-file - 存储可选管道多行字符串和可选参数的批处理脚本
- python - 以管理员身份运行 Anaconda macOS Catalina
- java - targetCompatability 和使用该 JDK 运行 Gradle 之间有区别吗?