tensorflow - 在不同机器上保存和加载 Universal Sentence Encoder 模型
问题描述
我们如何在不同的机器上保存和加载通用句子编码器模型?
我用 USE 创建了一个 Keras 模型并将其保存在机器 A 上。
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.models import Model, load_model
import tensorflow_hub as hub
import tensorflow as tf
module_url = "/path/on/machine/A/universal-sentence-encoder_4"
emb = hub.KerasLayer(module_url, input_shape=[], dtype=tf.string, trainable=True)
input1 = Input(shape=[], dtype=tf.string)
embedding_layer = emb(input1)
dense1 = Dense(units=512, activation="relu")(embedding_layer)
outputs = Dense(1, activation="sigmoid")(dense1)
model = Model(inputs=input1, outputs=outputs)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["AUC"])
model.save("model.h5", include_optimizer=False)
现在我想model.h5
在机器 B 上打开。预训练的 USE 保存在这里/different/path/on/machine/B/universal-sentence-encoder_4
。这是我得到的错误。
model = load_model("model.h5", custom_objects={"KerasLayer": hub.KerasLayer})
~/anaconda3/envs/tensorflow/lib/python3.8/site-packages/tensorflow_hub/resolver.py in __call__(self, handle)
494 def __call__(self, handle):
495 if not tf.compat.v1.gfile.Exists(handle):
--> 496 raise IOError("%s does not exist." % handle)
497 return handle
498
OSError: /path/on/machine/A/universal-sentence-encoder_4 does not exist.
我该如何解决这个问题?有没有办法将所有内容保存universal-sentence-encoder_4
到一个model.h5
文件中,这样用户就不必担心使用?
张量流版本:2.4.1
keras 版本:2.4.0
更新:根据WGierke的建议,创建了Google Colab来演示该问题。
解决方案
这里的问题也与您在这个 SO 问题中观察到的有关:保存 Keras 模型需要序列化模型中包含的所有内容。当您hub.KerasLayer
使用通用的 Callable like初始化 a 时loaded_obj
,它不能被序列化。相反,您必须传递一个handle
指向 SavedModel 路径的字符串(或 tfhub.dev URL)。当 Keras 模型被保存时,KerasLayer.get_config
被调用,它将该字符串存储在带有 key 的配置条目中handle
。
恢复 Keras 模型时,配置被反序列化并运行 Keras hub.KerasLayer(config["handle"])
。如您所见,如果存储的模型handle
不再可用,这将失败。
不幸的是,目前唯一的解决方法是确保引用的路径在机器 B 上也可用。
推荐阅读
- typescript - 无法加载 tsc.ps1,因为在此系统上禁用了运行脚本
- jquery - 获取由Jquery mvc4中的foreach循环迭代的div内的值?
- python - 为什么 pandas 在日期索引表中查找日期时会生成 KeyError?
- java - TcpInboundGateway 服务器配置 - 超时时要发送自定义消息
- javascript - Angular Reactive Forms - 验证错误 - 无法读取未定义或 null 的属性
- python - pytest:使用 re.escape() 断言转义字符失败
- javascript - Javascript/Jquery 在 flex 水平滚动 div 中添加鼠标和触摸可拖动/可滑动滚动
- jquery - 旋转外盒时旋转内盒
- shell - 这个 AppleScript/shell 代码是否以任何方式以纯文本形式公开输入?
- firebase - 如何调试 firestore.rules 变量和函数?