tensorflow - Tensorflow Estimator Hook 访问传递给 model_fn 的特征标签和在 model_fn 期间构建的图形操作
问题描述
我正在尝试了解使用 tensorflow Estimator 框架构建的模型。我想使用 Hooks API 添加在评估或预测期间处理输入的操作。
看来我应该能够model_fn
在训练期间单独使用,并实现我自己的SessionRunHook
类来添加操作,但是我如何获得模型的输入张量?例如,假设model_fn
看起来像
def model_fn(features, labels, mode, params):
concatanated_features = prepare_inputs(features, params)
...
prepare_inputs
只是做类似的事情
def prepare_inputs(features, params):
return tf.feature_column.input_layer(features, params['column_names'])`
然后我做类似的事情
class MyHook(tf.train.SessionRunHook):
def begin(self):
self.myTensor = my_function(features) # but how do I get features?
self.myTensor2 = my_function(concatanated_features) # likewise,
gr = tf.get_default_graph() # seems I have to start here and know what I'm looking for
解决方案
我建议你看看TF Serve
预测。您可以使用提供的 gRPC/REST API 调用您保存的模型以获取预测。
而且,您可以在生成用于 REST 调用的 JSON 请求之前执行任何所需的预处理。这个来自 TF 的例子包括Serve
:
https ://www.tensorflow.org/tfx/tutorials/serving/rest_simple?hl=en
推荐阅读
- go - 使用升级版 protoc-gen-go 和 protoc 编译器重新生成后 protobuf 中的新字段
- python-3.x - 如何合并这个字典列表?
- python - 查询从 numpy 数组中取值
- html - Edge 中带有不必要的垂直滚动条的页面
- java - swt 菜单上单击隐藏显示复合
- azure-resource-manager - 429错误时如何在Azure帐户中增加读取次数请求限制
- sql - 在包含顶级 UNION INTERSECT 或 EXCEPT 运算符的语句中不允许变量赋值
- pandas-groupby - pandas groupby 提取前百分比 n 数据(降序)
- android - 在颤振应用程序中获得互联网许可
- tcp - 尝试通过 AT 命令连接到 4G GSM 模块时 TCP/IP 连接被拒绝