keras - 为什么我的自动编码器模型没有学习?
问题描述
我正在尝试使用自动编码器解决验证码数据集。数据集是RGB 图像。
我将RGB图像转换为一个通道,即:
(图像的形状为 (48, 200))。
所以我接下来要做的是使用验证码的文本(在我们的例子中是“emwpn”),并用这个文本创建另一个具有相同形状(48、200)的图像,即:
我尝试的是为自动编码器的编码器提供验证码,并为解码器提供我创建的图像。
不知道这个方法会不会好用,没想到什么都学不到。当我试图预测测试数据集时,我得到的只是紫色图像,即:
capchas_array_test_pred = conv_ae.predict(capchas_array_test)
plt.imshow(capchas_array_test_pred[1])
这意味着自动编码器为所有图像的所有像素预测 0。
这是 conv 自动编码器的代码:
def rounded_accuracy(y_true, y_pred):
return keras.metrics.binary_accuracy(tf.round(y_true), tf.round(y_pred))
conv_encoder = keras.models.Sequential([
keras.layers.Reshape([48, 200, 1], input_shape=[48, 200]),
keras.layers.Conv2D(16, kernel_size=5, padding="SAME"),
keras.layers.BatchNormalization(),
keras.layers.Activation("relu"),
keras.layers.Conv2D(32, kernel_size=5, padding="SAME", activation="selu"),
keras.layers.Conv2D(64, kernel_size=5, padding="SAME", activation="selu"),
keras.layers.AvgPool2D(pool_size=2),
])
conv_decoder = keras.models.Sequential([
keras.layers.Conv2DTranspose(32, kernel_size=5, strides=2, padding="SAME", activation="selu",
input_shape=[6, 25, 64]),
keras.layers.Conv2DTranspose(16, kernel_size=5, strides=1, padding="SAME", activation="selu"),
keras.layers.Conv2DTranspose(1, kernel_size=5, strides=1, padding="SAME", activation="sigmoid"),
keras.layers.Reshape([48, 200])
])
conv_ae = keras.models.Sequential([conv_encoder, conv_decoder])
conv_ae.compile(loss="mse", optimizer=keras.optimizers.Adam(lr=1e-1), metrics=[rounded_accuracy])
history = conv_ae.fit(capchas_array_train, capchas_array_rewritten_train, epochs=20,
validation_data=(capchas_array_valid, capchas_array_rewritten_valid))
该模型没有学到任何东西:
Epoch 2/20
24/24 [==============================] - 1s 53ms/step - loss: 60879.9883 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
Epoch 3/20
24/24 [==============================] - 1s 53ms/step - loss: 60878.5781 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
Epoch 4/20
24/24 [==============================] - 1s 53ms/step - loss: 60879.2656 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
Epoch 5/20
24/24 [==============================] - 1s 53ms/step - loss: 60876.4648 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
Epoch 6/20
24/24 [==============================] - 1s 53ms/step - loss: 60878.4883 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
Epoch 7/20
24/24 [==============================] - 1s 53ms/step - loss: 60880.8242 - rounded_accuracy: 0.0637 - val_loss: 60930.7344 - val_rounded_accuracy: 0.0635
我试图检查如果我为编码器和解码器提供相同的图像会发生什么检查:
conv_ae.compile(loss="mse", optimizer=keras.optimizers.Adam(lr=1e-1), metrics=[rounded_accuracy])
history = conv_ae.fit(capchas_array_train, capchas_array_train, epochs=20,
validation_data=(capchas_array_valid, capchas_array_valid))
我又得到了紫色图像:
Ps 如果你有兴趣,这是笔记本: https ://colab.research.google.com/drive/1gA1XN1NOZKylGDhVu4PKXWhrPU4q9Ady
编辑-
这是我对图像进行的预处理:
1. Convert RGB image to one channel.
2. Normalize the image from value from 0 to 255 for each pixel, to 0 to 1.
3. Resize the (50, 200) image to (48, 200) - for simpler pooling in the autoencoder (48 can be divided by 2 more times, and stay integer, than 50)
这是预处理 1,2 步骤的功能:
def rgb2gray(rgb):
r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
gray = (0.2989 * r + 0.5870 * g + 0.1140 * b)
for x in range(rgb.shape[1]):
for y in range(rgb.shape[0]):
if gray[y][x]>128:
gray[y][x] = 1.0
else:
gray[y][x] = 0.0
return gray
解决方案
- 你的架构没有任何意义。如果您想创建一个自动编码器,您需要了解您将在编码后反转过程。这意味着如果你有三个卷积层,过滤器的顺序是:64、32、16;你应该做下一组卷积层来做相反的事情:16、32、64。这就是你的算法没有学习的原因。
- 你不会得到你期望的结果。您将获得与这种验证码类似的结构,但您不会清楚地输出文本。如果需要,您需要另一种算法(允许您进行字符分割的算法)。
推荐阅读
- java - 如何更改 MaterialAlertDialog 主题
- mongodb - 如何在mongodb中获取单个数组中的数据?
- r - 使用 R,如何开发一个名为 `setOption` 的函数?
- html - 使用 Bootstrap Modal 选择下拉列表 Z 索引问题
- python - 从 Python Pandas / Dask 中的 Parquet 文件中读取一组行?
- google-oauth - 如何将社区连接器(应用程序脚本)连接到 GA4 属性,以便它像高级服务和 UA 一样运行?
- python - 如何将填充有数字的字符串转换为列表中的单独整数?
- java - 如何在 JPanel 和 JFrame 的 contentPane 之间添加间距?
- html - CSS类继承?
- python - 更改数据框的索引:获取属性错误