首页 > 解决方案 > Tensorflow Estimator Hook 访问传递给 model_fn 的特征标签和在 model_fn 期间构建的图形操作

问题描述

我正在尝试了解使用 tensorflow Estimator 框架构建的模型。我想使用 Hooks API 添加在评估或预测期间处理输入的操作。

看来我应该能够model_fn在训练期间单独使用,并实现我自己的SessionRunHook类来添加操作,但是我如何获得模型的输入张量?例如,假设model_fn看起来像

def model_fn(features, labels, mode, params):
    concatanated_features = prepare_inputs(features, params)
    ...

prepare_inputs只是做类似的事情

def prepare_inputs(features, params):
    return tf.feature_column.input_layer(features, params['column_names'])`

然后我做类似的事情

class MyHook(tf.train.SessionRunHook):
    def begin(self):
        self.myTensor = my_function(features) # but how do I get features?
        self.myTensor2 = my_function(concatanated_features) # likewise,
        gr = tf.get_default_graph() # seems I have to start here and know what I'm looking for

标签: tensorflowtensorflow-estimator

解决方案


我建议你看看TF Serve预测。您可以使用提供的 gRPC/REST API 调用您保存的模型以获取预测。

而且,您可以在生成用于 REST 调用的 JSON 请求之前执行任何所需的预处理。这个来自 TF 的例子包括Servehttps ://www.tensorflow.org/tfx/tutorials/serving/rest_simple?hl=en


推荐阅读