python - TF 会话图为空 - 从两个导入的图创建新的 tf 图
问题描述
现在,我正在使用两个模型,其中一个从检查点加载,另一个加载冻结模型。在一个类中,我有一个 loadmodel 函数和另一个处理函数,我想将这两个模型应用到一个图像上,以便将来的程序只需调用该类的该方法就可以轻松地处理图像上的所有内容。
这是代码的简化示例,仅适用于我在课堂上苦苦挣扎的方法:
def segmented_stylization(self, img, target_size = 3840):
self.g1 = self.seg_graph.as_graph_def()
self.g2 = self.eval_graph.as_graph_def()
start = time.time()
# Here we need to create a combined graph
with tf.Graph().as_default() as g_combined:
# Get both our stylized frame and segmentation mappings by passing in the original frame input
style_input = self.eval_graph.get_tensor_by_name("img_placeholder:0")
style_output = self.eval_graph.get_tensor_by_name("style_outputs:0")
seg_input = self.seg_graph.get_tensor_by_name("ImageTensor:0")
seg_output = self.seg_graph.get_tensor_by_name("SemanticPredictions:0")
seg_map = seg_output[0]
# Take our segmented mapping and identify the pixels corresponding to any human(s) in the picture
hPixels = tf.where(tf.equal(seg_map, self.humanIdx))
# Replace the pixels corresponding to a human
#mframe = tf.scatter_nd_update(style_output, hPixels, tf.gather_nd(resized_img, hPixels))
# mframe = tf.image.resize_nearest_neighbor(style_output (self.output_width, self.output_height))
#mframe = tf.clip_by_value(mframe, 0, 255)
sess_comb = tf.Session(graph = g_combined)
a = sess_comb.run([hPixels], feed_dict = {"img_placeholder:0" : resized_img, "ImageTensor:0" : resized_img})
output_img = sess_comb.run([mframe], feed_dict = {"img_placeholder:0" : resized_img, "ImageTensor:0" : resized_img})
return output_img
我上面得到的错误是
raise RuntimeError('The Session graph is empty. Add operations to the '
RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
我还尝试在 with 语句中包装组合会话,尝试嵌套和外部,但还没有让它工作。
我尝试过的另一件事是代替 get_tensor_by_name:
self.g1 = self.seg_graph.as_graph_def()
self.g2 = self.eval_graph.as_graph_def()
seg_output, = tf.import_graph_def(self.g1, input_map = {"ImageTensor:0" : resized_img, "img_placeholder:0" : resized_img}, return_elements = ["SemanticPredictions:0"])
style_output, = tf.import_graph_def(self.g2, input_map={"img_placeholder:0" : resized_img}, return_elements=["style_outputs:0"])
但这会返回错误:
ValueError: Input 0 of node import/Shape was passed double from import/_inputs/Const:0 incompatible with expected uint8.
任何见解将不胜感激,谢谢!对 tf 来说还是新手 :)。
解决方案
推荐阅读
- python - 如何解决这个最大递归深度错误?
- node.js - 为不和谐服务器制作一个简单的音乐机器人,不起作用
- ruby-on-rails - 代理上的 Rails HTTP 请求中继器并返回响应
- jenkins - 如何在声明性管道中使用矩阵部分
- python - 尝试使用 Pandas Lambda 更改行值时出现语法错误
- excel - 使用excel VBA从字符串数组中删除重复项
- vue.js - 注销未从本地存储或状态清除令牌
- git - 如何禁用git删除分支的能力?
- django - 基于 CharField 用 ForeignKey 替换 CharField
- java - 如何更改 barChart ApachePoi 中文本的颜色