tensorflow - Tensorflow:如何将图像与自定义常量过滤器进行卷积
问题描述
我有 3 个 5x5 过滤器,我想在灰度图像(形状 [nx,ny,1])输入上进行卷积。我已经预设了这些 5x5 过滤器中的每一个都需要的硬编码值,我不希望它们被我的模型“学习”,而只是一个恒定的操作。
我如何实现这一目标?
我正在研究使用 tf.nn.conv2d() 并且它说它的过滤器需要是形状 [height, width, input, output] 所以我尝试使用 tf.constant() 为我的形状过滤器创建一个张量[5,5,1,3](因此 3 个形状为 5x5 的滤波器应用于具有 1 个通道的输入)但 tf.constant() 的结果看起来不正确。结果是这样的:
[[[[ -5 7 -12]]
[[ 21 0 2]]
[[ -6 9 -6]]
[[ 2 -2 8]]
[[-6 4 -1]]]
[[[ 2 -6 8]]
[[ -6 2 -1]]
[[ 2 -2 2]]
[[ -1 1 5]]
[[ 4 3 2]]]
...etc
它看起来不像 3 个 5x5 过滤器的形状。
如果我使用形状为 [1,3,5,5] 的 tf.constant() 我得到这个:
[[[[ -5 7 -12 21 0]
[ 2 -6 9 -6 2]
[ -2 8 -6 4 -1]
[ 2 -6 8 -6 2]
[ -1 2 -2 2 -1]]
[[ 1 5 4 3 2]
[ 4 0 -2 0 4]
[ 2 -1 7 -3 5]
[ -1 0 -1 0 -1]
[ 5 0 9 0 5]]
...etc
看起来确实像 5x5 过滤器,但它不是 tf.nn.conv2d() 采用的正确形状
所以我对这种不匹配感到困惑,不知道该怎么做才是正确的。
解决方案
最好不要担心过滤器的外观。只需跟踪形状以确保它们有意义。
这是一个将 2 个 Sobel 过滤器应用于图像的示例:
from skimage import data
img = np.expand_dims(data.camera(), -1)
img = np.expand_dims(img, 0) # shape: (1, 512, 512, 1)
sobel_x = np.array([[-0.25, -0.2 , 0. , 0.2 , 0.25],
[-0.4 , -0.5 , 0. , 0.5 , 0.4 ],
[-0.5 , -1. , 0. , 1. , 0.5 ],
[-0.4 , -0.5 , 0. , 0.5 , 0.4 ],
[-0.25, -0.2 , 0. , 0.2 , 0.25]])
sobel_y = np.array([[-0.25, -0.4 , -0.5 , -0.4 , -0.25],
[-0.2 , -0.5 , -1. , -0.5 , -0.2 ],
[ 0. , 0. , 0. , 0. , 0. ],
[ 0.2 , 0.5 , 1. , 0.5 , 0.2 ],
[ 0.25, 0.4 , 0.5 , 0.4 , 0.25]])
filters = np.concatenate([[sobel_x], [sobel_y]]) # shape: (2, 5, 5)
filters = np.expand_dims(filters, -1) # shape: (2, 5, 5, 1)
filters = filters.transpose(1, 2, 3, 0) # shape: (5, 5, 1, 2)
# Convolve image
ans = tf.nn.conv2d((img / 255.0).astype('float32'),
filters,
strides=[1, 1, 1, 1],
padding='SAME')
with tf.Session() as sess:
ans_np = sess.run(ans) # shape: (1, 512, 512, 2)
filtered1 = ans_np[0, ..., 0]
filtered2 = ans_np[0, ..., 1]
图像与 2 个过滤器正确卷积,生成的图像如下所示:
plt.matshow(filtered1)
plt.show()
plt.matshow(filtered2)
plt.show()
推荐阅读
- vb.net - 在 Visual Basic 窗体中单击时隐藏动态创建的按钮
- reactjs - Expo - React Native:Firebase 电话号码身份验证失败
- spring - 为什么在启动 Spring Boot 应用程序之前无法登录?
- reactjs - “vector.project”到相机的标准化屏幕位置?(反应三纤维)
- c++ - 如何在运行时使用模型切换 SQL 数据库
- xamarin.forms - Xamarin Forms - 内容页面中的 webkit
- c - 如何用C语言写日记?
- python - TypeError:导入 sklearn 时需要一个整数(获取类型字节)
- c++ - 如何在 MacOS (OS X) 上将 GNU make 4.3 设为默认
- android - 如何实现具有弯曲角的可绘制形状