tensorflow - Tensorflow Keras Conv2D 多个过滤器
问题描述
如果我有一个 1X2X3X3 输入(我首先使用通道)和权重 2X2X2X2 如下图所示,我不太了解 Keras Conv2D 输出,有人可以帮我理解输出特征图,过滤器如何对输入进行卷积得到输出?
这是我的代码:
import os
import tensorflow as to
import tensorflow.python.util.deprecation as deprecation
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, Conv2D
data = tf.range(3 * 3 * 2)
print(data)
data = tf.reshape(data, (1, 2, 3, 3))
print(data)
print('-------')
e = tf.range(2 * 2 * 2 * 2)
print(e)
e = tf.reshape(e, (2, 2, 2, 2))
print(e)
print('-------')
model = Sequential()
model.add(Conv2D(2, (2, 2), input_shape=(2, 3, 3), data_format='channels_first'))
weights = [e, tf.constant([0.0,0.0])]
model.set_weights(weights)
print(model.get_weights())
yhat = model.predict(data)
print(yhat.shape)
print(yhat)
解决方案
如果您在查看每个操作员时改变视角,则更容易理解。您有一个形状为 1x2x3x3 的输入。由于您使用的是data_format='channels_first'
,这意味着您有 1 个具有 2 个通道且大小为 3x3 的图像。您可以像这样可视化该图像:
| [ 0 9] [ 1 10] [ 2 11] |
| [ 3 12] [ 4 13] [ 5 14] |
| [ 6 15] [ 7 16] [ 8 17] |
这是您的 3x3 图像,其中每个“像素”有两个通道。过滤器形状为 2x2x2x2,这意味着 2x2 过滤器从 2 个通道变为 2 个通道。这可以这样表示:
| 0 1 | | 4 5 |
| 2 3 | | 6 7 |
| 8 9 | | 12 13 |
| 10 11 | | 14 15 |
这是您的 2x2 过滤器,其中每个过滤器位置包含一个 2x2 矩阵。结果,形状为 1x2x2x2,是 1 张具有 2 个通道且大小为 2x2 的图像:
| [456 508] [512 571] |
| [624 700] [680 764] |
为了理解操作是如何工作的,我将介绍输出的第一个“像素”的计算,[456 508]
. 此输出是从输入图像中的第一个 2x2 窗口计算得出的:
| [ 0 9] [ 1 10] |
| [ 3 12] [ 4 13] |
您要做的是获取每个“像素”(二元素向量)并将它们乘以过滤器中相应位置的矩阵:
# Top-left
| 0 1 |
[ 0 9] x | | = [18 27]
| 2 3 |
# Top-right
| 4 5 |
[ 1 10] x | | = [64 75]
| 6 7 |
# Bottom-left
| 8 9 |
[ 3 12] x | | = [144 159]
| 10 11 |
# Bottom-right
| 12 13 |
[ 4 13] x | | = [230 247]
| 14 15 |
然后,您只需添加所有结果向量:
[18 27] + [64 75] + [144 159] + [230 247] = [456 508]
其余的输出以相同的方式计算,例如,[512 571]
将通过将过滤器应用于下一个图像窗口来计算输出:
| [ 1 10] [ 2 11] |
| [ 4 13] [ 5 14] |
等等。
推荐阅读
- c# - 为不同的难度创造多个高分
- c# - 有没有办法在 JSON 文件的值中获取键名
- bash - 使用 sed 更改 URL 中的斜线方向
- struct - 错误:TSortedMap 以自定义结构为键,重载 operator<
- html - Flexbox / 由于嵌套元素而扩展高度
- detox - 有没有办法选择排毒测试的顺序?
- apache-flink - 在 Flink 集群上运行的 Apache Beam 管道失败
- r - R:如何读取单元格引用以及从 Excel 读取数据
- c - 如何在 C 中使用 strtol 将数字从命令行转换为 inter?
- javascript - 是否可以从该函数采用的回调内部退出(本机)函数?