python - 从自动编码器获取瓶颈层的输出
问题描述
我是自动编码器的新手。我已经构建了一个简单的卷积自动编码器,如下所示:
# ENCODER
input_img = Input(shape=(64, 64, 1))
encode1 = Conv2D(32, (3, 3), activation=tf.nn.leaky_relu, padding='same')(input_img)
encode2 = MaxPooling2D((2, 2), padding='same')(encode1)
l = Flatten()(encode2)
l = Dense(100, activation='linear')(l)
# DECODER
d = Dense(1024, activation='linear')(l)
d = Reshape((32,32,1))(d)
decode3 = Conv2D(64, (3, 3), activation=tf.nn.leaky_relu, padding='same')(d)
decode4 = UpSampling2D((2, 2))(decode3)
model = models.Model(input_img, decode4)
model.compile(optimizer='adam', loss='mse')
# Train it by providing training images
model.fit(x, y, epochs=20, batch_size=16)
现在在训练这个模型之后,我想从瓶颈层(即密集层)获得输出。这意味着如果我将形状数组 (1000, 64, 64) 扔到模型中,我想要压缩的形状数组 (1000, 100)。
我尝试了如下所示的一种方法,但它给了我一些错误。
model = Model(inputs=[x], outputs=[l])
呃:
ValueError: Input tensors to a Functional must come from `tf.keras.Input`.
我也尝试了其他一些方法,但这也不起作用。有人可以告诉我如何在训练模型后恢复压缩数组。
解决方案
您需要为encoder
. 训练整个系统后encoder-decoder
,只能encoder
用于预测。代码示例:
# ENCODER
input_img = layers.Input(shape=(64, 64, 1))
encode1 = layers.Conv2D(32, (3, 3), activation=tf.nn.leaky_relu, padding='same')(input_img)
encode2 = layers.MaxPooling2D((2, 2), padding='same')(encode1)
l = layers.Flatten()(encode2)
encoder_output = layers.Dense(100, activation='linear')(l)
# DECODER
d = layers.Dense(1024, activation='linear')(encoder_output)
d = layers.Reshape((32,32,1))(d)
decode3 = layers.Conv2D(64, (3, 3), activation=tf.nn.leaky_relu, padding='same')(d)
decode4 = layers.UpSampling2D((2, 2))(decode3)
model_encoder = Model(input_img, encoder_output)
model = Model(input_img, decode4)
model.fit(X, y, epochs=20, batch_size=16)
model_encoder.predict(X)
应该为每个图像返回一个向量。
推荐阅读
- mysql - 在 MySQL 中使用 CHECK 和 LIKE?
- vue.js - 如何更改 nuxt.js 中的 vue-intro.js css
- web-crawler - 微服务可以根据资源使用情况进行拆分吗?
- java - 以编程方式设置约束布局的比例 x 和 y 以适应任何设备中的屏幕
- app-store - 自制的 Android 应用程序不适合所有人,仅供公司使用
- sql - 使用 .net 从 MS Access 数据库中删除不存在的记录
- nextcloud - 无法上网 Nextcloud Connection Assistant
- python - Asyncio 服务器 - 如何通过访问 recipent 客户端的 writer 对象在两个客户端之间正确传达消息?
- django - 我是否必须使用 Django 模型才能使用 API?
- django - keep_lazy 用于通过 Django API 发送的翻译消息