首页 > 解决方案 > 从保存的 DNN 估计器进行预测,Predict_Dict 无效

问题描述

我在存储和加载经过训练的模型时遇到问题。我正在导出我的模型 feature_spec = { 'BUY': parsing_ops.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=1), 'ASK': parsing_ops.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=2), 'DIFF': parsing_ops.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=3) } classifier.export_savedmodel('Y:\Checkers\Model1', tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))

并且这部分代码运行顺利,我只是在我的导出出现偏差的情况下将其包括在内。然后我尝试将此模型加载到另一个文件中

BUY = float(PriceState.bidVol)
ASK = float(PriceState.askVol)
DIFF = float(PriceState.bidVol) - float(PriceState.askVol)
predict_dict = {
    "inputs": [BUY, ASK, DIFF]
}
predict_fn = predictor.from_saved_model('Y:/Checkers/Model1/1529118977')
predictions = predict_fn(predict_dict)

但是,在这里我得到了错误

" 文件 "./GDAXfinal.py",第 40 行,在 GetPrediction 预测 = predict_fn(predict_dict) 文件 "C:\Users\Andy\Python\lib\site-packages\tensorflow\contrib\predictor\predictor.py",行77,待命 返回 self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict) 文件“C:\Users\Andy\Python\lib\site-packages\tensorflow\python\client\session.py”,第 900 行,运行中run_metadata_ptr)文件“C:\Users\Andy\Python\lib\site-packages\tensorflow\python\client\session.py”,第 1135 行,_run feed_dict_tensor,选项,run_metadata)文件“C:\Users\Andy\ Python\lib\site-packages\tensorflow\python\client\session.py”,第 1316 行,在 _do_run run_metadata 中)文件“C:\Users\Andy\Python\lib\site-packages\tensorflow\python\client\session .py”,第 1335 行,在 _do_call raise type(e)(node_def, op, message) tensorflow.python.framework.errors_impl.InternalError: 无法将元素作为字节获取。 "

我尝试通过首先将浮点数转换为字符串,然后转换为字节(bytes(str(<myVal>), encoding="utf-8"))来将输入设置为字节

然后我收到以下错误

"

InvalidArgumentError(参见上面的回溯):无法解析示例输入,值:'0.0' [[节点:ParseExample/ParseExample = ParseExample[Ndense=3,Nsparse=0,Tdense=[DT_FLOAT,DT_FLOAT,DT_FLOAT],dense_shapes=[ [1], [1], [1]], sparse_types=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_example_tensor_0_0, ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/ParseExample/dense_keys_1, ParseExample/ParseExample/dense_keys_2, ParseExample/Reshape, ParseExample/Reshape_1, ParseExample/Reshape_2)]]

"

我猜这是因为输入是浮点数,但它们被作为字符串接收。但是,由于某种原因,如果我设置

"inputs": [b'0.', b'0.', b'0.']

那么预测就会顺利通过。

但是任何轻微的变化都会引发 float / str 错误。

例子:

"inputs": [b'0.0', b'0.0', b'0.0']
 "inputs": [b'10.', b'10.', b'10.']

抛出错误。

我已经尝试遵循尽可能多的在线教程来导出和加载 tensorflow 预制估计器,但似乎没有一个遇到这个问题。

标签: python-3.xtensorflow

解决方案


build_parsing_serving_input_receiver_fn指解析序列化的tf.train.Example协议缓冲区。要按原样使用 SavedModel,您需要构造和序列化 atf.train.Example并将其作为单个字符串提供。

还有一个tf.estimator.export.build_raw_serving_input_receiver_fn可以让您分别提供特征字典的每个元素。这应该适用于您的第一个示例(将可转换的内容传递给浮点 Numpy 数组)。


推荐阅读