tensorflow - 在 {TF 2.0.0-beta1 上使用 tflite_convert 时出现“未知(自定义)损失函数”;Keras} 模型
问题描述
概括
我的问题由以下组成:
- 我展示我的项目、我的工作环境和我的工作流程的背景
- 详细的问题
- 我的代码的相关部分
- 我试图解决我的问题的解决方案
- 问题提醒
语境
我已经编写了原始超分辨率 GAN 的降级版本的 Python Keras 实现。现在我想使用 Google Firebase 机器学习工具包对其进行测试,方法是将其托管在 Google 服务器中。这就是为什么我必须将我的 Keras 程序转换为 TensorFlow Lite 的原因。
环境和工作流程(有问题)
我正在 Google Colab 工作环境中训练我的程序:在那里,我已经安装了TF 2.0.0-beta1
(这个选择是由这个不正确的答案引起的:https ://datascience.stackexchange.com/a/57408/78409 )。
工作流程(和问题):
我在本地编写我的 Python Keras 程序,记住它将在 TF 2 上运行。所以我使用 TF 2 导入,例如
from tensorflow.keras.optimizers import Adam
:from tensorflow.keras.layers import Conv2D, BatchNormalization
我将代码发送到我的云端硬盘
我运行我的 Google Colab Notebook: TF 2 没有任何问题。
我在我的驱动器中获得了输出模型,然后下载了它。
我尝试通过执行以下 CLI 将此模型转换为 TFLite 格式
tflite_convert --output_file=srgan.tflite --keras_model_file=srgan.h5
::这里出现了问题。
问题
之前的 CLI 没有从 TF (Keras) 模型输出 TF Lite 转换后的模型,而是输出此错误:
ValueError:未知损失函数:build_vgg19_loss_network
该函数build_vgg19_loss_network
是我实现的自定义损失函数,GAN 必须使用它。
引发此问题的部分代码
呈现自定义损失函数
自定义损失函数是这样实现的:
def build_vgg19_loss_network(ground_truth_image, predicted_image):
loss_model = Vgg19Loss.define_loss_model(high_resolution_shape)
return mean(square(loss_model(ground_truth_image) - loss_model(predicted_image)))
使用我的自定义损失函数编译生成器网络
generator_model.compile(optimizer=the_optimizer, loss=build_vgg19_loss_network)
我为解决问题而尝试做的事情
当我在 StackOverflow(本问题开头的链接)上阅读它时,TF 2 被认为足以输出一个可以由我的
tflite_convert
CLI 正确处理的 Keras 模型。但显然不是。当我在 GitHub 上阅读它时,我尝试通过添加以下行在 Keras 的损失函数中手动设置我的自定义损失函数
import tensorflow.keras.losses tensorflow.keras.losses.build_vgg19_loss_network = build_vgg19_loss_network
:它没有用。我在 GitHub 上读到我可以使用带有
load_model
Keras 函数的自定义对象:但我只想使用compile
Keras 函数。不是load_model
。
我的最后一个问题
我只想对我的代码做一些小的改动,因为它工作正常。因此,例如,我不想替换compile
为load_model
. 有了这个约束,你能帮我让我的 CLItflite_convert
与我的自定义损失函数一起工作吗?
解决方案
由于您声称 TFLite 转换由于自定义损失函数而失败,因此您可以保存模型文件而不保留优化器详细信息。为此,请将include_optimizer
参数设置为 False,如下所示:
model.save('model.h5', include_optimizer=False)
现在,如果模型中的所有层都是可转换的,它们应该被转换为 TFLite 文件。
编辑:然后您可以像这样转换 h5 文件:
import tensorflow as tf
model = tf.keras.models.load_model('model.h5') # srgan.h5 for you
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
此处记录了克服 TFLite 转换中不受支持的运算符的常规做法。
推荐阅读
- python - Raspberry PI 3B+ 上的 OpenVino 和 PyInstaller
- sql - 为什么我的存储过程需要很长时间才能提供数据
- python - 列表的第一个元素和最后一个元素,Python
- sql-server - 如何将特定的列和行数据插入到现有表中,从一个表到另一个表?
- android - 分阶段推出:人们仍然可以进入 google playstore 下载新版本吗?
- python - 正则表达式代码以捕获包含日期的 whatsapp 消息
- typescript - 如何遍历在 NestJs Bull 中注册的所有队列?
- r - Shiny:renderPrint 一个不带 [1] 且不带转义字符的字符串
- javascript - 两个浮点数之和返回错误值
- c# - 通过按钮单击C#减少变量