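Keras 2.2: Cannot load a prebuilt model with imagenet weights

Problem description

I have code that works in older versions of Keras, but under Keras 2.2 it fails with an error about loading weights that have too few layers into a larger model: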

import keras
from keras.layers import MaxPooling2D, AveragePooling2D,  Conv2D
from keras.applications import Xception
from keras.layers.normalization import BatchNormalization
from keras.layers import Input, Concatenate, Add
from keras.layers.advanced_activations import LeakyReLU

kernel_size = (3, 3)  
pool_size = (2, 2)  
nfilters = 3
inputs = Input(shape=(331, 331, 1))
x = inputs
x = Conv2D(nfilters, kernel_size, strides=(1,1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.1)(x)
x = MaxPooling2D(pool_size=pool_size)(x)
x =  Add()([x,AveragePooling2D(pool_size=pool_size)(inputs)])  # residual skip connection on shrunk image
base_model = Xception(weights='imagenet', include_top=False, input_tensor=x)

The error I get is related to Xception:

ValueError: You are trying to load a weight file containing 80 layers into a model with 82 layers.

A Colab notebook reproduces this.

The problem occurs when loading the imagenet weights; if I set weights to None, there is no problem.

In load_weights() calls this kind of error can be avoided by passing by_name=True, but prebuilt models such as Xception do not accept a by_name keyword.

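Can anyone explain how to get my code running again under Keras 2.2?

I suppose I could define Xception twice, once on its own with the imagenet weights and once inside my full model with weights=None, and then copy the weights from the former to the latter, but I would rather not do that if I can avoid it.

(In case you are wondering why these layers come before Xception: I am shrinking a larger image to the size that Xception's imagenet weights expect, and converting my grayscale image into a 3-channel image.)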

Tags: python, tensorflow, keras

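Solution

I am not entirely sure how to explain your error, but you can make it work by treating the Xception model as a layer: call it on the output of the preceding layers, then wrap the whole stack in a Model instance. I verified the following in your Colab notebook.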

import keras
from keras.layers import MaxPooling2D, AveragePooling2D,  Conv2D
from keras.applications import Xception
from keras.layers.normalization import BatchNormalization
from keras.layers import Input, Concatenate, Add
from keras.layers.advanced_activations import LeakyReLU

kernel_size = (3, 3)  
pool_size = (2, 2)  
nfilters = 3
inputs = Input(shape=(331, 331, 1))
x = inputs
x = Conv2D(nfilters, kernel_size, strides=(1,1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
x = LeakyReLU(alpha=0.1)(x)
x = MaxPooling2D(pool_size=pool_size)(x)
x =  Add()([x,AveragePooling2D(pool_size=pool_size)(inputs)])  # residual skip connection on shrunk image

# Xception architecture is just another layer
base_model = Xception(weights='imagenet', include_top=False)
output = base_model(x)
# Wrap everything into a model
combined_model = keras.models.Model(inputs=inputs, outputs=output)

This gives you a model that looks like this:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 331, 331, 1)  0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 331, 331, 3)  27          input_2[0][0]                    
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 331, 331, 3)  12          conv2d_6[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)       (None, 331, 331, 3)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 165, 165, 3)  0           leaky_re_lu_2[0][0]              
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 165, 165, 1)  0           input_2[0][0]                    
__________________________________________________________________________________________________
add_14 (Add)                    (None, 165, 165, 3)  0           max_pooling2d_2[0][0]            
                                                                 average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
xception (Model)                multiple             20861480    add_14[0][0]                     
==================================================================================================
Total params: 20,861,519
Trainable params: 20,806,985
Non-trainable params: 54,534
__________________________________________________________________________________________________
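
The figures in this summary can be cross-checked with plain arithmetic. The sketch below is only an illustration: the helper names pooled_size, conv2d_params and batchnorm_params are hypothetical, the formulas are the usual size rules for pooling, convolution and batch normalization, and every number is copied from the code and summary above rather than produced by Keras.

```python
# Cross-check the summary with plain arithmetic (no Keras required).
# All helper names are hypothetical; the inputs come from the code and summary above.

def pooled_size(size, pool, stride=None):
    """Output length along one axis of a pooling layer with 'valid' padding."""
    stride = stride or pool
    return (size - pool) // stride + 1

def conv2d_params(kernel_h, kernel_w, in_channels, filters, use_bias):
    """Number of weights in a Conv2D layer."""
    bias = filters if use_bias else 0
    return kernel_h * kernel_w * in_channels * filters + bias

def batchnorm_params(channels):
    """(trainable, non_trainable): gamma and beta vs. moving mean and variance."""
    return 2 * channels, 2 * channels

side = pooled_size(331, 2)                          # 2x2 pooling on a 331-pixel side
conv = conv2d_params(3, 3, 1, 3, use_bias=False)    # 3x3 kernel, 1 -> 3 channels
bn_train, bn_frozen = batchnorm_params(3)

xception_total = 20861480                           # "xception (Model)" row
total = xception_total + conv + bn_train + bn_frozen
trainable = total - 54534                           # 54,534 non-trainable per the summary

print(side, conv, bn_train + bn_frozen, total, trainable)
# prints: 165 27 12 20861519 20806985
```

Conv2D and BatchNormalization are the only two layers in front of Xception that carry weights, which matches the gap between 80 and 82 in the error message. One plausible reading, which the answer does not confirm, is that passing input_tensor made those two layers part of the graph the weight file was loaded into, whereas here Xception stays a self-contained nested model.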
