tensorflow - 从 TensorFlow Estimator 模型 (2.0) 保存、加载和预测
问题描述
Estimator
在 TF2 中是否有序列化和恢复模型的指南?文档参差不齐,其中大部分没有更新到 TF2。我还没有在任何地方看到一个清晰且完整的示例,Estimator
可以保存、从磁盘加载并用于根据新输入进行预测。
TBH,我对这看起来有多复杂感到有些困惑。估计器被称为拟合标准模型的简单、相对高级的方法,但在生产中使用它们的过程似乎非常神秘。例如,当我从磁盘加载模型时,tf.saved_model.load(export_path)
我得到一个AutoTrackable
对象:
<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7fc42e779f60>
不清楚为什么我不Estimator
回来。看起来曾经有一个听起来很有用的函数tf.contrib.predictor.from_saved_model
,但由于contrib
它消失了,它似乎不再发挥作用(除了它出现在 TFLite 中)。
任何指针都会非常有帮助。正如你所看到的,我有点迷路了。
解决方案
也许作者不再需要答案,但我能够使用 TensorFlow 2.1 保存和加载 DNNClassifier
# training.py
from pathlib import Path
import tensorflow as tf
....
# Creating the estimator
estimator = tf.estimator.DNNClassifier(
model_dir = <model_dir>,
hidden_units = [1000, 500],
feature_columns = feature_columns, # this is a list defined earlier
n_classes = 2,
optimizer = 'adam')
feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
servable_model_path = Path(estimator.export_saved_model(<model_dir>, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')
对于加载,您找到了正确的方法,您只需要检索predict_fn
# testing.py
import tensorflow as tf
import pandas as pd
def predict_input_fn(test_df):
'''Convert your dataframe using tf.train.Example() and tf.train.Features()'''
examples = []
....
return tf.constant(examples)
test_df = pd.read_csv('test.csv', ...)
# Loading the estimator
predict_fn = tf.saved_model.load(<model_dir>).signatures['predict']
# Predict
predictions = predict_fn(examples=predict_input_fn(test_df))
希望这也可以帮助其他人(:
推荐阅读
- python - 使用 Cartopy 绘制跨越国际日期变更线的线
- python - 如何在 PandasSql 中读取嵌套的 Json 文件
- python - 如何根据输入重复几行代码
- .htaccess - .htaccess ModRewrite 带 2 个参数
- tailwind-css - prettier 2.3 版中类属性的基于前缀的多行格式
- c# - 在 C# 中将 Base64 图像添加到邮件中
- java - Mac OS Catalina 10.15.7 (19H1030) 获取 java.io.IOException 错误=7,运行构建时参数列表太长
- c# - Autodesk Forge 尝试在线访问 API 时出错
- r - 如果给定单元格大于 100,000,则使用 M(百万)后缀格式化数字
- excel - Excel - 如何搜索特定子字符串的范围,然后从找到子字符串的单元格中输出整个值?