python - create_module_spec - tfhub
问题描述
我正在尝试创建一个简单的音频识别来发现关键词。由于我的数据集很小,我正在执行迁移学习。这就是图表的外观。按照这个链接我创建了一个模块。这是代码
import tensorflow_hub as hub
import tensorflow as tf
# pylint: disable=unused-import
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
# pylint: enable=unused-import
def module_fn():
input_name = "Reshape:0"
output_name = "Reshape_2:0"
graph_def = tf.GraphDef()
with open('my_frozen_graph.pb', "rb") as f:
graph_def.ParseFromString(f.read())
input_ten=tf.placeholder(tf.float32, shape = (1, 98, 40))
output_ten,=tf.import_graph_def(graph_def, input_map = {input_name: input_ten}, return_elements = [output_name])
hub.add_signature(inputs = input_ten, outputs = output_ten)
spec = hub.create_module_spec(module_fn)
module = hub.Module(spec)
with tf.Session() as session:
module.export('test_module',session)
尽管它确实执行并创建了一个“test_module”文件夹。
test_module
|--> assets
|--> variables
|--> saved_model.pb
|--> tfhub_module.pb
我怎么有几个问题
变量文件夹为空。不确定这是否应该是这样的?
input_ten=tf.placeholder(tf.float32, shape = (1, 98, 40))
这个对吗 ?98X48 是图像大小,第一个元组通常代表批量大小。它应该保持为 '1' 还是未知的批量大小 'None' ?将模块加载到脚本后
高度,宽度 = hub.get_expected_image_size('test_module')
给我一个错误。
解决方案
让我试着依次回答你的问题。
如果您从中构建模型的图形定义确实被冻结(即,所有变量都已被常量替换),则没有变量需要写入通常位于变量/变量*的检查点。所以这对我来说是可以解释的。-- 也就是说,Hub 模块将为您提供一种避免冻结图定义的方法:在 module_fn 中调用原始图构建代码,并在调用 Module.export() 之前在会话中恢复预训练的变量。
对于您的模块类型,您可以制定规则。;-) 集线器模块可以适应各种输入和输出形状,包括部分或完全未知的形状。像上面这样的输入占位符必须具有与您插入的图形兼容的形状。反过来,该图将使用与其正在执行的卷积一起工作的形状。一般来说,将前导维度用于批量大小并保留未指定 (
None
) 通常很有用。hub.get_expected_image_size() 用于图像输入。我会在这里避免它。
推荐阅读
- java - Swagger 开放 API 定义不适用于 Micronaut JWT 安全 Micronaut 版本 2.2.1
- php - PHP 在一个字符串之后,在一个字符串内返回一个字符串
- ruby-on-rails - Rails 6 注册模式与设计
- php - Wordpress 5.6 不支持 PHP 8?
- java - 使用 guice 绑定一个番石榴供应商
- r - 收集多个列
- postgresql - Postgres 查询每个用户超过 30 天的登录次数,每天只计算一次
- python - OpenCV 太慢了
- java - 如何使用 Java 身份验证连接到 Cassandra 5.1
- selenium-webdriver - 读取 excel 文件 - 如何读取不同的工作表名称