python - 第二个卷积层参数的个数是否正确?
问题描述
对于 MNIST 数据问题,我有简单的 CNN。
cnn_model = tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=24, kernel_size=(3,3), activation='relu'),
tf.keras.layers.Conv2D(filters=36, kernel_size=(3,3), activation='relu'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(10, activation='softmax')
])
这就是摘要的样子:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_12 (Conv2D) (None, 26, 26, 24) 240
_________________________________________________________________
conv2d_13 (Conv2D) (None, 24, 24, 36) 7812
_________________________________________________________________
flatten_13 (Flatten) (None, 20736) 0
_________________________________________________________________
dense_26 (Dense) (None, 128) 2654336
_________________________________________________________________
dense_27 (Dense) (None, 10) 1290
=================================================================
Total params: 2,663,678
Trainable params: 2,663,678
Non-trainable params: 0
_________________________________________________________________
为了问题的简单性,我跳过了问题中的池层。
第一个卷积层有 240 个参数,易于计算:(内核大小 + 偏差)* 过滤器数量:(3*3+1)*24。请解释一下为什么第二个卷积层有 7812 个参数(36 * 217)。
flatten 层的大小为 20736。这是上一层的 36 个过滤器产生的像素数:24 * 24 * 36。
但是我们如何才能从上一层的 24 张图像中通过 36 个过滤器获得 36 张图像呢?展平层的大小不应该是 36 * 24 * 24 * 24,即前一层的过滤器数量 * 前一层的位图大小 * 第一个卷积层的过滤器数量?
解决方案
卷积层的参数数量为
(filter_height * filter_width * in_channels * out_channels) + out_channels
在你的情况下,那是
(3 * 3 * 24 * 36) + 36 = 7,812
这种卷积的输出形状是
(n_samples, remaining_height, remaining_width, n_filters)
推荐阅读
- java - 抛出 RecyclerView 后 layout_centerInParent 不起作用
- javascript - 如何在 JavaScript 中使用 Promise 在一段时间后打印日志
- javascript - 如何使用动画css3将div放入框中?
- django - Gunicorn ModuleNotFoundError:没有名为“django”的模块
- spring-integration - Spring集成服务激活器使用基于输入的bean名称
- laravel - 使用 Ajax 刷新 Laravel 图表
- c# - EF Core 中的合并冲突
- android - 使用 Google Maps Activity 创建项目时出现 Android Studio 错误
- javascript - 如何从 URL 加载 JSON 并在 HTML 页面上使用纯 javascript 显示?
- android - 如何从 decodeImage(File('assets/logo.png').readAsBytesSync()) 加载图像;飘飘然