deep-learning - 构建自动编码器时收到错误
问题描述
我正在尝试使用 CNN 作为编码器和 LSTM 作为解码器为我的学期项目构建一个自动编码器,但是当我显示模型的摘要时。我收到以下错误:
ValueError:输入 0 与层 lstm_10 不兼容:预期 ndim=3,发现 ndim=2
x.shape = (45406, 100, 100)
y.shape = (45406,)
我已经尝试改变 LSTM 的输入形状,但没有奏效。
def keras_model(image_x, image_y):
model = Sequential()
model.add(Lambda(lambda x: x / 127.5 - 1., input_shape=(image_x, image_y, 1)))
last = model.output
x = Conv2D(3, (3, 3), padding='same')(last)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), padding='valid')(x)
encoded= Flatten()(x)
x = LSTM(8, return_sequences=True, input_shape=(100,100))(encoded)
decoded = LSTM(64, return_sequences = True)(x)
x = Dropout(0.5)(decoded)
x = Dense(400, activation='relu')(x)
x = Dense(25, activation='relu')(x)
final = Dense(1, activation='relu')(x)
autoencoder = Model(model.input, final)
autoencoder.compile(optimizer="Adam", loss="mse")
autoencoder.summary()
model= keras_model(100, 100)
解决方案
鉴于您使用的是 LSTM,您需要一个时间维度。所以你的输入形状应该是:(时间,image_x,image_y,nb_image_channels)。
我建议更深入地了解自动编码器、LSTM 和 2D 卷积,因为所有这些都在这里发挥作用。这是一个有用的介绍:https ://machinelearningmastery.com/lstm-autoencoders/和这个https://blog.keras.io/building-autoencoders-in-keras.html)。
也看看这个例子,有人用 Conv2D 实现了一个 LSTM如何重塑 3 通道数据集以输入到神经网络。TimeDistributed 层在这里很有用。
但是,为了解决您的错误,您可以添加一个 Reshape() 层来伪造额外的维度:
def keras_model(image_x, image_y):
model = Sequential()
model.add(Lambda(lambda x: x / 127.5 - 1., input_shape=(image_x, image_y, 1)))
last = model.output
x = Conv2D(3, (3, 3), padding='same')(last)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling2D((2, 2), padding='valid')(x)
encoded= Flatten()(x)
# (50,50,3) is the output shape of the max pooling layer (see model summary)
encoded = Reshape((50*50*3, 1))(encoded)
x = LSTM(8, return_sequences=True)(encoded) # input shape can be removed
decoded = LSTM(64, return_sequences = True)(x)
x = Dropout(0.5)(decoded)
x = Dense(400, activation='relu')(x)
x = Dense(25, activation='relu')(x)
final = Dense(1, activation='relu')(x)
autoencoder = Model(model.input, final)
autoencoder.compile(optimizer="Adam", loss="mse")
print(autoencoder.summary())
model= keras_model(100, 100)
推荐阅读
- php - PHP:如何通过计算日期的出现来从另一个数组创建一个新的数据数组?
- json - 用于 Gmail 注释的 CatalogCardLayout(轮播图像)
- php - Laravel 将用户同步到所有可以不受限制的类别
- google-apps-script - 从工作表中的某个点获取第一列
- python - Python Crash Course Game 显示未正确显示图像
- scala - 如何在 Scala 中模拟 JDBC 结果集
- apache-spark - Spark on Kubernetes 故障排除
- typescript - 如果构造函数是私有的,则评估参数是否为“泛型类的实例”的泛型
- reactjs - 滚动到页面上某个组件的“反应”方式是什么?
- java - 使用 JSON 识别 JsonPrimitve 是 BigDecimal 还是 Integer