python - 使用 TF1 读取使用 TF2 创建的 protobuf
问题描述
我有一个存储为 hdf5 的模型,我使用 saved_model.save 将其导出到 protobuf (PB) 文件,如下所示:
from tensorflow import keras
import tensorflow as tf
model = keras.models.load_model("model.hdf5")
tf.saved_model.save(model, './output_dir/')
这工作正常,结果是一个 saved_model.pb 文件,我以后可以用其他软件查看它,没有问题。
但是,当我尝试使用 TensorFlow1 导入此 PB 文件时,我的代码失败了。由于 PB 应该是一种通用格式,这让我感到困惑。
我用来读取 PB 文件的代码是这样的:
import tensorflow as tf
curr_graph = tf.Graph()
curr_sess = tf.InteractiveSession(graph=curr_graph)
f = tf.gfile.GFile('model.hdf5','rb')
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
f.close()
这是我得到的例外:
回溯(最近一次调用):文件“read_pb.py”,第 14 行,在 graph_def.ParseFromString(f.read()) google.protobuf.message.DecodeError:解析消息时出错
我有一个存储为 PB 文件的不同模型,在该文件上读取代码可以正常工作。
这是怎么回事?
***** 编辑 1 *****
在下面使用 Andrea Angeli 的代码时,我遇到了以下错误:
遇到错误:NodeDef 在 Op y:T、batch_mean:U、batch_variance:U、reserve_space_1:U、reserve_space_2:U、reserve_space_3:U 中没有提到 attr 'exponential_avg_factor';attr=T:type,allowed=[DT_HALF, DT_BFLOAT16, DT_FLOAT]; attr=U:type,allowed=[DT_FLOAT]; attr=epsilon:float,default=0.0001; attr=data_format:string,default="NHWC",allowed=["NHWC", "NCHW"]; attr=is_training:bool,default=true>; NodeDef:{node u-mobilenetv2/bn_Conv1/FusedBatchNormV3}。(检查您的 GraphDef 解释二进制文件是否与您的 GraphDef 生成二进制文件是最新的。)。
有解决方法吗?
解决方案
您正在尝试读取 hdf5 文件,而不是您保存的 protobuf 文件tf.saved_model.save(..)
。还要注意,TF2 导出的 protobuf 与 TF 1 的冻结图不同,因为它只包含计算图。
编辑 1: 如果要从 TF 2 模型导出 TF 1 样式的冻结图,可以使用以下代码片段完成:
from tensorflow.python.framework import convert_to_constants
def export_to_frozen_pb(model: tf.keras.models.Model, path: str) -> None:
"""
Creates a frozen graph from a keras model.
Turns the weights of a model into constants and saves the resulting graph into a protobuf file.
Args:
model: tf.keras.Model to convert into a frozen graph
path: Path to save the profobuf file
"""
inference_func = tf.function(lambda input: model(input))
concrete_func = inference_func.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
output_func = convert_to_constants.convert_variables_to_constants_v2(concrete_func)
graph_def = output_func.graph.as_graph_def()
graph_def.node[-1].name = 'output'
with open(os.path.join(path, 'saved_model.pb'), 'wb') as freezed_pb:
freezed_pb.write(graph_def.SerializeToString())
这将在您在path
param 中指定的位置生成一个 protobuf 文件 (saved_model.pb)。您的图形的输入节点将具有名称“input:0”(这是由 lambda 实现的)和输出节点“output:0”。
推荐阅读
- graphql - graphql join-monster 与strapi io的兼容性
- php - PHP 为什么显式类型转换 + 1 可以工作,但不能使用增量运算符
- assembly - 对齐给定指令,但将对齐填充放置在指令之前以外的位置
- swift - 在 swift 4 中迭代 3 度嵌套字典
- firebase - ReferenceError:使用 onAuthStateChanged 时未定义 firebase 错误
- sql - 如果我的 where 条件使用 Hive 给出空输出,如何显示表中的所有记录
- xml - 如何解决 Strings.xml 中高棉字符集的特定无效字符?
- git - 在 Windows 上删除 Git 的缓存密码
- javascript - Angularjs - 在 ng-repeat 中 Md-select 中断
- javascript - Ant design - 有没有办法获得我实际使用的样式?