python - Keras InceptionV3 TypeError: unhashable type: 'Dimension'
问题描述
我正在尝试实现一个模型,它将灰度图像作为输入并返回一个数值作为输出。我使用 InceptionV3(从头开始训练)作为特征提取器,然后使用一些密集层进行最后阶段的回归。
这是我的代码:
from keras.applications.inception_v3 import InceptionV3
from keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout, Flatten, BatchNormalization
from keras.models import Model
from keras.metrics import mean_absolute_error
from keras.utils import plot_model
inputs = Input(shape=(256, 256, 1))
x = BatchNormalization()(inputs)
x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape[1:])(x)
x = BatchNormalization()(x)
x = GlobalAveragePooling2D()(x)
x = Dense(1000, activation = 'relu' )(x)
x = Dense(1000, activation = 'relu' )(x)
outputs = Dense(1, activation = 'linear' )(x)
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer = 'adam', loss = 'mse', metrics = [mae])
model.summary()
现在,当我运行代码时出现此错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-36-50041eb640cc> in <module>()
7 inputs = Input(shape=(256, 256, 1))
8 x = BatchNormalization()(inputs)
----> 9 x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape[1:])(x)
10 x = BatchNormalization()(x)
11 x = GlobalAveragePooling2D()(x)
3 frames
/usr/local/lib/python3.6/dist-packages/keras_applications/imagenet_utils.py in _obtain_input_shape(input_shape, default_size, min_size, data_format, require_flatten, weights)
273 default_shape = (input_shape[0], default_size, default_size)
274 else:
--> 275 if input_shape[-1] not in {1, 3}:
276 warnings.warn(
277 'This model usually expects 1 or 3 input channels. '
TypeError: unhashable type: 'Dimension'
我不明白是什么导致了错误,因为当我使用顺序模型时它绝对没问题。但它不适用于这个功能模型。
解决方案
inputs.shape
不是列表,因此会引发错误。它为您提供带有类型的形状,tensorflow.python.framework.tensor_shape.TensorShape
其中包含带有类型的每个维度的列表Dimension
print(inputs.shape)
# output TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)])
您可以使用as_list()
获取形状作为列表:
# inputs.shape.as_list()
# output [None, 256, 256, 1]
x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape.as_list()[1:])(x)
推荐阅读
- javascript - 如果我有多个表单,如何在表格行中显示元素
- java - 检查文本字段是否有多个点
- javascript - 如何迭代哈希映射以响应该映射包含 3 个列表的 ajax 调用
- java - 使用 Firebase 中的键显示特定数据
- python-3.x - 通过传递相关方法在两个熊猫数据框列之间建立相关性
- terminal - anaconda spyder 运行时如何运行`conda`命令?
- html - 为什么我的菜单不显示与代码相同的行?
- ios - 如何根据段宽调整 UISegmentedControl 字体大小?
- racket - 字符串前缀?\在模块中获取错误未绑定标识符
- python - numpy:用 2d 插值调整 3d 数组的大小