首页 > 技术文章 > Tensorflow模型的格式

ying-chease 2019-03-19 18:03 原文

参考转载:
https://cloud.tencent.com/developer/article/1009979
https://blog.csdn.net/qq_27825451/article/details/105866464

tensorflow模型的格式通常支持多种,主要有CheckPoint(*.ckpt)、GraphDef(*.pb)、SavedModel。

 

1. CheckPoint(*.ckpt)

在训练 TensorFlow 模型时,每迭代若干轮需要保存一次权值到磁盘,称为“checkpoint”,如下图所示:

这种格式文件是由 tf.train.Saver() 对象调用 saver.save() 生成的,只包含若干 Variables 对象序列化后的数据,不包含图结构,所以只给 checkpoint 模型不提供代码是无法重新构建计算图的。

载入 checkpoint 时,调用 saver.restore(session, checkpoint_path)。

缺点:首先模型文件是依赖 TensorFlow 的,只能在其框架下使用;其次,在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中。

 

2. GraphDef(*.pb)

这种格式文件包含 protobuf 对象序列化后的数据,包含了计算图,可以从中得到所有运算符(operators)的细节,也包含张量(tensors)和 Variables 定义,但不包含 Variable 的值,因此只能从中恢复计算图,但一些训练的权值仍需要从 checkpoint 中恢复。下面代码实现了利用 *.pb 文件构建计算图:

TensorFlow 一些例程中用到 *.pb 文件作为预训练模型,这和上面 GraphDef 格式稍有不同,属于冻结(Frozen)后的 GraphDef 文件,简称 FrozenGraphDef 格式。这种文件格式不包含 Variables 节点。将 GraphDef 中所有 Variable 节点转换为常量(其值从 checkpoint 获取),就变为 FrozenGraphDef 格式。代码可以参考 tensorflow/python/tools/freeze_graph.py

*.pb 为二进制文件,实际上 protobuf 也支持文本格式(*.pbtxt),但包含权值时文本格式会占用大量磁盘空间,一般不用。

 

3. SavedModel

https://juejin.im/post/5bbfedd65188255c9b13d964

https://zhuanlan.zhihu.com/p/31417693

这是谷歌推荐的模型保存方式,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。该格式为 GraphDef 和 CheckPoint 的结合体,另外还有标记模型输入和输出参数的 SignatureDef。从 SavedModel 中可以提取 GraphDef 和 CheckPoint 对象。

SavedModel 目录结构如下:

其中 saved_model.pb(或 saved_model.pbtxt)包含使用 MetaGraphDef protobuf 对象定义的计算图;assets 包含附加文件;variables 目录包含 tf.train.Saver() 对象调用 save() API 生成的文件。

以下代码实现了保存 SavedModel:

方法1:

#在模型创建并保存
#1.1 在model中
创建signature def signature_def(self): inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs), 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs), 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)} outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)} return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs ,outputs=outputs ,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

#1.2 保存模型
def save_model(self, sess, signature, save_path):
        builder = tf.saved_model.builder.SavedModelBuilder(save_path)
        builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], {'predict': signature}, clear_devices=True)
        builder.save()


#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):
signature = predict_signature_def(inputs={
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout},
outputs={'decode_tags': model.decode_tags}
)
return signature

 

方法2:

#在模型创建并保存
#2.1 在model中创建signature
def signature_def(self):
inputs = {'char_inputs': tf.saved_model.utils.build_tensor_info(self.char_inputs)
, 'seg_inputs': tf.saved_model.utils.build_tensor_info(self.seg_inputs)
, 'dropout': tf.saved_model.utils.build_tensor_info(self.dropout)}

outputs = {'decode_tags': tf.saved_model.utils.build_tensor_info(self.decode_tags)}
return tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs
, outputs=outputs
,method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
#1.2 保存模型
def save_model(self, sess, signature, save_path):
builder = tf.saved_model.builder.SavedModelBuilder(save_path)
builder.add_meta_graph_and_variables(sess=sess
, tags=[tf.saved_model.tag_constants.SERVING]
, signature_def_map=signature
, clear_devices=True)
builder.save()
 
#1.3 保存模型中的signature参数使用get_signature方法创建
def get_signature(model):
    inputs = {
'char_inputs': model.char_inputs,
'seg_inputs': model.seg_inputs,
'dropout': model.dropout}
outputs = {'decode_tags': model.decode_tags}

signature = {
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
tf.saved_model.signature_def_utils.predict_signature_def(inputs, outputs)
}

return signature

4、模型载入
# encoding = utf8

import tensorflow as tf
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)

model_path = 'xxxx'
meta_graph = tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], model_path)

print('load succeed!')

  char_inputs_ = signature['predict'].inputs['char_inputs'].name
  seg_inputs_ = signature['predict'].inputs['seg_inputs'].name
  dropout_ = signature['predict'].inputs['dropout'].name
  decode_tags_ = signature['predict'].outputs['decode_tags'].name
  # get tensor   char_inputs = sess.graph.get_tensor_by_name(char_inputs_)   seg_inputs = sess.graph.get_tensor_by_name(seg_inputs_)   dropout = sess.graph.get_tensor_by_name(dropout_)   decode_tags = sess.graph.get_tensor_by_name(decode_tags_)   decode_tags_ = sess.run([decode_tags], feed_dict={char_inputs: inputs[1], seg_inputs:inputs[2], dropout:1.0 })

更多细节可以参考 tensorflow/python/saved_model/README.md。

5. 各模式之间的转换

https://zhuanlan.zhihu.com/p/47649285

 

6. 小结

本文总结了 TensorFlow 常见模型格式和载入、保存方法。部署在线服务(Serving)时官方推荐使用 SavedModel 格式,而部署到手机等移动端的模型一般使用 FrozenGraphDef 格式(最近推出的 TensorFlow Lite 也有专门的轻量级模型格式 *.lite,和 FrozenGraphDef 十分类似)。这些格式之间关系密切,可以使用 TensorFlow 提供的 API 来互相转换。

 

推荐阅读