首页 > 解决方案 > TensorFlow 2.0 和 TensorFlow Hub:load_module_spec 等效?

问题描述

当使用 TensorFlow 1.x 和 TensorFlow hub 时,我们可以加载模块的规范来检查预期的输出形状(可能还有其他有用的规范!),如下所示:

spec = hub.load_module_spec("https://tfhub.dev/google/nnlm-en-dim128/1")
shape = spec.get_output_info_dict()['default'].get_shape()

当尝试对兼容 TF 2.0 的集线器模块执行相同操作时,我在调用时遇到以下错误消息load_module_spec

缺少支持的实现:loader(*('/tmp/tfhub_modules/82c4aaf4250ffb09088bd48368ee7fd00e5464fe',), **{})

是否有其他方法可以检查 TF 2.0 集线器模块的输出形状?

标签: pythontensorflowtensorflow-hub

解决方案


对于 TensorFlow 2,TF Hub 将切换到提供 TF2 的原生基于对象的 SavedModels [ docRFC ]。tf.saved_model.load()如果它们已经在您的文件系统上,或者hub.load()从 URL 可选下载,则它们会被加载。这为您提供了一个恢复的Trackable对象,其__call__成员的行为类似于 a @tf.function,这意味着它具有一个或多个具体函数,每个函数都由 TF 图支持,并根据张量形状/dtypes 和非张量参数在它们之间分派。

在 TF2 的当前 alpha 版本中,如果您知道输入的允许 TensorSpec,您可以深入到输出,例如:

loaded_model = hub.load("https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1")
concrete_function = loaded_model.__call__.get_concrete_function(
    tf.TensorSpec((None,), tf.string))
print(concrete_function.output_shapes, ":",
      concrete_function.output_dtypes)

这给了我

(None, 128) : <dtype: 'float32'>

推荐阅读