tensorflow - 在 python 中加载 tf.saved_model 后如何获取输入和输出张量
问题描述
假设我用以下代码保存了一个模型
tf.saved_model.simple_save(sess, export_dir, in={'input_x': x, 'input_y':y}, out={'output_z':z})
现在我将保存的模型加载到另一个 python 程序中
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ['serve'], export_dir)
现在的问题是如何通过调用 simple_save() 方法时在输入/输出参数中指定的 'input_x'、'input_y'、'output_z' 键获取 x、y、z 张量的句柄?
我在网上找到的唯一解决方案是在创建 x、y、z 张量时显式命名它们,然后使用这些名称从图中检索它们,这似乎是相当多余的,因为我们在调用 simple_save() 时为它们指定了键.
解决方案
我确实遇到了您的问题,经过一些调查(我认为 TF 文档很差),我找到了下一个解决方案:
使用返回的 MetaGraphDef 对象来查找您的输入\输出名称映射。
graph = tf.Graph()
with graph.as_default():
metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],save_path)
inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
此代码将为您提供保存到“TensorInfo”对象时提供的名称之间的映射,您可以从他那里轻松获取映射的张量名称,例如:
my_input = inputs_mapping['my_input_name'].name
my_input_t = graph.get_tensor_by_name(my_input)
推荐阅读
- node.js - 函数在 Firebase 云节点函数中返回未定义、预期的 Promise 或值
- java - 弹簧控制器没有得到扫描
- azure - 当我尝试创建新保管库并提供对特定应用程序的保管库访问权限时,在 Azure Vault 中通过 rest api 未提供访问权限?
- python - 如何使用 Scrapy 定位数据属性
- xml - 转换器是否可以忽略 XML 标记错误?
- mongodb - MongoDb 使用 ref 查询其他集合
- html - 该部分的背景图像到特定高度
- selenium-webdriver - IE11 Save As Popup 处理问题 WebDriver/Autoit/Robot
- python-3.x - 将无限列表解析为一个列表
- php - 缺少作者;谷歌网站管理员工具中缺少更新