python - 将模型另存为 H5 或 SavedModel 时出现 TensorFlow Hub 错误
问题描述
我想使用这个 TF Hub 资产: https ://tfhub.dev/google/imagenet/resnet_v1_50/feature_vector/3
版本:
Version: 1.15.0-dev20190726
Eager mode: False
Hub version: 0.5.0
GPU is available
代码
feature_extractor_url = "https://tfhub.dev/google/imagenet/resnet_v1_50/feature_vector/3"
feature_extractor_layer = hub.KerasLayer(module,
input_shape=(HEIGHT, WIDTH, CHANNELS))
我得到:
ValueError: Importing a SavedModel with tf.saved_model.load requires a 'tags=' argument if there is more than one MetaGraph. Got 'tags=None', but there are 2 MetaGraphs in the SavedModel with tag sets [[], ['train']]. Pass a 'tags=' argument to load this SavedModel.
我试过了:
module = hub.Module("https://tfhub.dev/google/imagenet/resnet_v1_50/feature_vector/3",
tags={"train"})
feature_extractor_layer = hub.KerasLayer(module,
input_shape=(HEIGHT, WIDTH, CHANNELS))
但是当我尝试保存模型时,我得到:
tf.keras.experimental.export_saved_model(model, tf_model_path)
# model.save(h5_model_path) # Same error
NotImplementedError: Can only generate a valid config for `hub.KerasLayer(handle, ...)`that uses a string `handle`.
Got `type(handle)`: <class 'tensorflow_hub.module.Module'>
教程在这里
解决方案
已经有一段时间了,但假设您已经迁移到 TF2,这可以通过最新的模型版本轻松完成,如下所示:
import tensorflow as tf
import tensorflow_hub as hub
num_classes=10 # For example
m = tf.keras.Sequential([
hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v1_50/feature_vector/5", trainable=True)
tf.keras.layers.Dense(num_classes, activation='softmax')
])
m.build([None, 224, 224, 3]) # Batch input shape.
# train as needed
m.save("/some/output/path")
如果这对您不起作用,请更新此问题。我相信您的问题是由于hub.Module
与hub.KerasLayer
. 您使用的模型版本是 TF1 Hub 格式,因此在 TF1 中它只能与. 一起使用hub.Module
,而不是与hub.KerasLayer
. 在 TF2 中,hub.KerasLayer
可以直接从其 URL 加载 TF1 Hub 格式模型,以便在更大的模型中组合,但无法对其进行微调。
请参阅此兼容性指南以获取更多信息
推荐阅读
- asp.net-mvc - 是否可以使用 JavaScript 将模型传递给局部视图?
- gradle - Gradle:如何测试特定包但忽略/不包含子包?
- powershell - EventLogPropertySelector 未从 PowerShell 中的事件对象返回扩展数据
- javascript - 如何修复 setTimeout 在新标签后变得更快
- python - 更快地加载 Django 对象
- c# - 无法从 RequestContext 中的 DependencyResolver 获取子容器
- python - Python json.dump 文件包含 %
- r - 我的 tidyverse 没有加载,当我输入 library(tidyverse) 时出现错误,为什么?
- amazon-web-services - 如何让 Kibana 读取指标的自定义索引?
- c++ - UE4:如何在 C++ 中创建 USceneComponent 并将其移动到蓝图视口中