首页 > 解决方案 > 如何使用经过训练和存储的张量流模型进行预测

问题描述

我有一个现有的训练模型(特别是 tensorflow word2vec https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/udacity/5_word2vec.ipynb)。我很好地恢复了现有模型:

model1 = tf.train.import_meta_graph("models/model.meta")
model1.restore(sess, tf.train.latest_checkpoint("model/"))

但我不知道如何使用新加载的(和训练好的)模型进行预测。如何使用恢复的模型进行预测?

编辑:

来自官方 tensorflow 存储库的模型代码https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/word2vec/word2vec_basic.py

标签: pythontensorflowword2vec

解决方案


根据您加载检查点的方式,我认为这应该是使用它进行推理的最佳方式。

加载占位符:

input = tf.get_default_graph().get_tensor_by_name("Placeholders/placeholder_name:0")
....

加载用于执行预测的操作:

prediction = tf.get_default_graph().get_tensor_by_name("SomewhereInsideGraph/prediction_op_name:0")

创建一个会话,执行预测操作,并在占位符中提供数据。

sess = tf.Session()
sess.run(prediction, feed_dict={input:input_data})

另一方面,我更喜欢做的是始终在类的构造函数中创建整个模型。然后,我要做的是:

tf.reset_default_graph()
model = ModelClass()
loader = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
loader.restore(sess, path_to_checkpoint_dir)

由于您想将训练过的 word2vec 模型中的嵌入加载到另一个模型中,因此您应该执行以下操作:

embeddings_new_model = tf.Variable(...,name="embeddings")
embedding_saver = tf.train.Saver({"embeddings_word2vec": embeddings_new_model})
with tf.Session() as sess:
    embedding_saver.restore(sess, "word2vec_model_path")

假设 word2vec 模型中的 embeddings 变量命名为embeddings_word2vec


推荐阅读