machine-learning - 如何构建 CNN 自动编码器?
问题描述
我有许多文档图像,我想对它们进行聚类以创建类别(发票、收据等)。我想探索图像方法(我知道我可以使用文本),所以我决定构建一个 CNN 自动编码器来将维度压缩到较低的空间,然后运行像 DBSCAN 这样的聚类算法。
我的问题是我不知道如何选择网络层和不同的激活函数等。这是我目前的模型,你怎么看?
model = Sequential()
model.add(Conv2D(16, (3, 3), strides=2, padding='same', kernel_regularizer = l2(), input_shape=image_rgb_dims_top))
model.add(LeakyReLU(alpha=0.2))
#model.add(AveragePooling2D(pool_size=(2,2), padding='same'))
model.add(Conv2D(32, (3, 3), strides=2, padding='same', kernel_regularizer = l2()))
model.add(LeakyReLU(alpha=0.2))
#model.add(AveragePooling2D(pool_size=(2,2), padding='same'))
model.add(Flatten())
model.add(Dense(96, activity_regularizer=l1(10e-6)))
model.add(Dense(np.prod(model.layers[-2].output_shape[1:]),activation='relu'))
model.add(Reshape(model.layers[-4].output_shape[1:]))
model.add(Conv2DTranspose(32,(3, 3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
#model.add(UpSampling2D((2, 2)))
model.add(Conv2DTranspose(16,(3, 3), strides=(2,2), padding='same'))
model.add(LeakyReLU(alpha=0.2))
#model.add(UpSampling2D((2, 2)))
model.add(Conv2D(1,(3, 3), padding='same'))
model.add(Activation('sigmoid'))
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d (Conv2D) (None, 20, 76, 16) 160 _________________________________________________________________ leaky_re_lu (LeakyReLU) (None, 20, 76, 16) 0 _________________________________________________________________ conv2d_1 (Conv2D) (None, 10, 38, 32) 4640 _________________________________________________________________ leaky_re_lu_1 (LeakyReLU) (None, 10, 38, 32) 0 _________________________________________________________________ flatten (Flatten) (None, 12160) 0 _________________________________________________________________ dense (Dense) (None, 96) 1167456 _________________________________________________________________ dense_1 (Dense) (None, 12160) 1179520 _________________________________________________________________ reshape (Reshape) (None, 10, 38, 32) 0 _________________________________________________________________ conv2d_transpose (Conv2DTran (None, 20, 76, 32) 9248 _________________________________________________________________ leaky_re_lu_2 (LeakyReLU) (None, 20, 76, 32) 0 _________________________________________________________________ conv2d_transpose_1 (Conv2DTr (None, 40, 152, 16) 4624 _________________________________________________________________ leaky_re_lu_3 (LeakyReLU) (None, 40, 152, 16) 0 _________________________________________________________________ conv2d_2 (Conv2D) (None, 40, 152, 1) 145 _________________________________________________________________ activation (Activation) (None, 40, 152, 1) 0 ================================================================= Total params: 2,365,793 Trainable params: 2,365,793 Non-trainable params: 0 _________________________________________________________________
我使用 MSE 和 adam 优化器。
我遇到的问题是:
- 该模型过度拟合数据集中出现最多的图像,因此当它们之间几乎没有差异时,它会为该文档的同一类型创建许多类别(只是添加了一个小徽标,它认为它是一个新的集群)
- 较少出现的图像没有得到足够的学习,我得到一个非常模糊的输出,并且它们中的大多数都被 DBSCAN 聚类为 -1。
知道如何使模型更有效吗?我不希望它过度拟合,但它对某些图像的拟合不足。
什么是好的层/激活函数/regluarizers 使用?我应该增加压缩表示大小还是减小它?很难对网络变化的影响进行基准测试,我所能做的就是运行 dbscan 集群并查看输出类,但这仍然取决于 dbscan epsilon 参数,所以我不知道模型是否做得好或不是。
解决方案
有很多事情可以尝试,在实际尝试之前很难知道它是否有效。
我将首先解决您明确提出的问题。
“模型过度拟合数据集中最常见的图像”
- 您可以尝试使用更大的数据集,或者如果您不能,使用较小的模型可能会奏效。
- 您可以尝试使用预训练模型并在其上运行迁移学习。
- 您可以尝试提前停止。
“较少出现的图像没有得到足够的学习”
- 您可以尝试使用不偏向特定类别的数据集。
“什么是好的层/激活函数/regluarizers 使用?”
- 对于激活函数,ReLU 及其变体大多运行良好。
- 您可以使用各种层和它们的组合。您为什么不尝试使用现代 SOTA(State of the art)CNN 网络的架构作为参考?(您可以在这里轻松找到其中的一些)
“这是我现在的模型,你觉得呢?”
- 至少,架构看起来很旧。如果它运作良好,那就没问题了。但是,如果需要,请尝试使用前面提到的现代 SOTA 架构。
“我应该增加压缩表示大小还是减小它?”
- 目前还不清楚。您应该尝试两者并选择性能更好的方法。
你也可以尝试不同的训练方法!手动标记所有图像将是一场彻头彻尾的噩梦,因此您可以尝试标记其中一些并运行半监督学习。例如)SimCLR
或者,您可以搜索有关文档图像分类的研究并将其用作参考。
希望答案有帮助!
推荐阅读
- java - Java SimpleDateFormat 无法与周年格式一起正常工作
- python - 使用 Python 正则表达式替换特定上下文中的所有匹配项
- python - 不和谐.py | 当机器人离开时,被踢,被服务器禁止
- python - WebScrape 所有相关的 URL/Hrefs
- html - Wordpress 页面标题中的斜体导致浏览器中出现奇怪的 HTML
- maven - AWS CodeCommit 和 Maven 发布插件
- json - json 算不算数据结构?如果不是,那么 JSON 的本质是什么?
- android - issue with addOnCompleteListener solved! -> running into W/System: Ignoring header X-Firebase-Locale because its value was null
- docker - 模拟 docker 节点之间的网络延迟
- android - 为整个屏幕设置 onClickListener