首页 > 解决方案 > 如果我无法使用 TensorFlow 下载预训练模型,如何手动加载它

问题描述

我正在尝试通过 TensorFlow 下载 VGG19 模型

base_model = VGG19(input_shape = [256,256,3],
                    include_top = False,
                    weights = 'imagenet')

但是,下载总是在完成下载之前卡住。我也尝试过不同的模型,比如 InceptionV3,同样的情况也发生在那里。

幸运的是,提示提供了可以手动下载模型的链接

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5
19546112/80134624 [======>.......................] - ETA: 11s

从给定链接下载模型后,我尝试使用导入模型

base_model = load_model('vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5')

但我得到这个错误

ValueError: No model found in config file.

如何手动加载下载的 .h5 模型?

标签: tensorflowkerasdeep-learningtransfer-learningvgg-net

解决方案


您正在使用load_model权重,而不是模型。您需要先定义模型,然后加载权重。

weights = "path/to/weights"
model = VGG19  # the defined model
model.load_weights(weights)  # the weights

推荐阅读