首页 > 解决方案 > 如何构建 CNN 自动编码器?

问题描述

我有许多文档图像,我想对它们进行聚类以创建类别(发票、收据等)。我想探索图像方法(我知道我可以使用文本),所以我决定构建一个 CNN 自动编码器来将维度压缩到较低的空间,然后运行像 DBSCAN 这样的聚类算法。

我的问题是我不知道如何选择网络层和不同的激活函数等。这是我目前的模型,你怎么看?

model = Sequential()

model.add(Conv2D(16, (3, 3), strides=2, padding='same', kernel_regularizer = l2(), input_shape=image_rgb_dims_top))
model.add(LeakyReLU(alpha=0.2))
#model.add(AveragePooling2D(pool_size=(2,2), padding='same'))

model.add(Conv2D(32, (3, 3), strides=2, padding='same', kernel_regularizer = l2()))
model.add(LeakyReLU(alpha=0.2))
#model.add(AveragePooling2D(pool_size=(2,2), padding='same'))


model.add(Flatten())
model.add(Dense(96, activity_regularizer=l1(10e-6)))
model.add(Dense(np.prod(model.layers[-2].output_shape[1:]),activation='relu'))
model.add(Reshape(model.layers[-4].output_shape[1:]))

model.add(Conv2DTranspose(32,(3, 3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
#model.add(UpSampling2D((2, 2)))

model.add(Conv2DTranspose(16,(3, 3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
#model.add(UpSampling2D((2, 2)))

model.add(Conv2D(1,(3, 3), padding='same'))
model.add(Activation('sigmoid'))
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 20, 76, 16)        160       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 20, 76, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 10, 38, 32)        4640      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 10, 38, 32)        0         
_________________________________________________________________
flatten (Flatten)            (None, 12160)             0         
_________________________________________________________________
dense (Dense)                (None, 96)                1167456   
_________________________________________________________________
dense_1 (Dense)              (None, 12160)             1179520   
_________________________________________________________________
reshape (Reshape)            (None, 10, 38, 32)        0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 20, 76, 32)        9248      
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 20, 76, 32)        0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 40, 152, 16)       4624      
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 40, 152, 16)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 40, 152, 1)        145       
_________________________________________________________________
activation (Activation)      (None, 40, 152, 1)        0         
=================================================================
Total params: 2,365,793
Trainable params: 2,365,793
Non-trainable params: 0
_________________________________________________________________

我使用 MSE 和 adam 优化器。

我遇到的问题是:

知道如何使模型更有效吗?我不希望它过度拟合,但它对某些图像的拟合不足。

什么是好的层/激活函数/regluarizers 使用?我应该增加压缩表示大小还是减小它?很难对网络变化的影响进行基准测试,我所能做的就是运行 dbscan 集群并查看输出类,但这仍然取决于 dbscan epsilon 参数,所以我不知道模型是否做得好或不是。

标签: machine-learningdata-sciencecluster-analysisautoencoderconv-neural-network

解决方案


有很多事情可以尝试,在实际尝试之前很难知道它是否有效。

我将首先解决您明确提出的问题。

“模型过度拟合数据集中最常见的图像”

  • 您可以尝试使用更大的数据集,或者如果您不能,使用较小的模型可能会奏效。
  • 您可以尝试使用预训练模型并在其上运行迁移学习。
  • 您可以尝试提前停止。

“较少出现的图像没有得到足够的学习”

  • 您可以尝试使用不偏向特定类别的数据集。

“什么是好的层/激活函数/regluarizers 使用?”

  • 对于激活函数,ReLU 及其变体大多运行良好。
  • 您可以使用各种层和它们的组合。您为什么不尝试使用现代 SOTA(State of the art)CNN 网络的架构作为参考?(您可以在这里轻松找到其中的一些)

“这是我现在的模型,你觉得呢?”

  • 至少,架构看起来很旧。如果它运作良好,那就没问题了。但是,如果需要,请尝试使用前面提到的现代 SOTA 架构。

“我应该增加压缩表示大小还是减小它?”

  • 目前还不清楚。您应该尝试两者并选择性能更好的方法。

你也可以尝试不同的训练方法!手动标记所有图像将是一场彻头彻尾的噩梦,因此您可以尝试标记其中一些并运行半监督学习。例如)SimCLR

或者,您可以搜索有关文档图像分类的研究并将其用作参考。

希望答案有帮助!


推荐阅读