python - 为什么 saved_model_cli 有效而加载 saved_model.pb 无效?
问题描述
我今天收到了一个新的 Tensorflow 保存模型(saved_model.pb + variables/),我正在尝试用它来预测图像。
我对输入和输出一无所知,但我可以使用saved_model_cli show得到一些。
使用saved_model_cli,我能够预测一张输入图像。
saved_model_cli run --dir /path/to/SavedModels/mymodel --signature_def model_signature --tag_set serve --inputs 'input=/path/to/image/as/numpy/image.npy' --outdir=/tmp/out --overwrite
我正在尝试在 Python 代码中加载此模型,但无法获得相同的结果。这是我尝试过的代码:
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.platform import gfile
with gfile.FastGFile(model_filename, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
if 1 != len(sm.meta_graphs):
print('More than one graph found. Not sure which to write')
sys.exit(1)
tf.compat.v1.import_graph_def(sm.meta_graphs[0].graph_def)
graph = tf.compat.v1.get_default_graph();
with graph.as_default():
session = tf.compat.v1.Session(graph=graph)
with session.as_default():
output_tensor = graph.get_tensor_by_name('import/output_tensorname:0')
image = load('/path/to/numpy/image/image-wxh.npy')
predictions = self.session.run(output_tensor,{'import/input:0': image})
print(predictions.shape)
当我运行此代码时,我会收到错误消息,例如
tensorflow.python.framework.errors_impl.FailedPreconditionError:发现 2 个根错误。(0) 失败的前提条件:从容器读取资源变量 conv2d_xx/kernel 时出错:localhost。这可能意味着该变量未初始化。未找到:容器 localhost 不存在。(找不到资源:localhost/conv2d_xx/kernel)
我错过了什么吗?
与 saved_model_cli 有什么区别?
解决方案
好吧,这是模型加载的问题。
import tensorflow as tf
import settings
import numpy as np
from PIL import Image
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.framework import convert_to_constants
def get_func_from_saved_model(saved_model_dir):
saved_model_loaded = tf.saved_model.load(
saved_model_dir, tags=[tag_constants.SERVING])
graph_func = saved_model_loaded.signatures['my_model_signature']
graph_func = convert_to_constants.convert_variables_to_constants_v2(graph_func)
return graph_func
export_path='/path/to/saved/model'
graph_func = get_func_from_saved_model(export_path)
path = '/path/to/image/image.png'
img = Image.open(path)
img.show()
img_data = np.array(img).reshape(1,224,224,1)
y_pred = graph_func(tf.convert_to_tensor(img_data))[0].numpy()
y_pred 包含与saved_model_cli脚本相同的输出图像。
推荐阅读
- ios - 无法转换类型“(结果
) -> Void' 到预期的参数类型 '(Result<_>) -> Void' - audio - 如何监听rtp数据包
- ms-access - 访问数据库字符串到日期转换问题
- node.js - 一起看 SASS/SCSS 和 nodemon
- php - laravel工厂中如何使用faker分别插入日期和时间
- java - Java 单元测试验证 ServletOutputStream
- javascript - THREE.js Orbit 控制缩放限制
- traefik - 某些入口点的 Traefik 全局前端规则
- facebook-marketing-api - 作为洞察 API 的细分,我们可以期待哪些可能的值
- android - 将 Retrofit2 响应传递给活动