python - 了解 tf.nn.depthwise_conv2d
问题描述
来自 https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d
给定一个 4D 输入张量('NHWC' 或 'NCHW' 数据格式)和一个形状为 [filter_height, filter_width, in_channels, channel_multiplier] 的滤波器张量,其中包含深度为 1 的 in_channels 卷积滤波器,depthwise_conv2d 对每个输入通道应用不同的滤波器(扩展从 1 个通道到每个通道的 channel_multiplier 通道),然后将结果连接在一起。输出有 in_channels * channel_multiplier 通道
- “从 1 个频道扩展到每个频道的 channel_multiplier 频道”是什么意思?
- 是否有可能有 out_channels < in_channels?
- 是否可以将输入张量划分为 Pytorch https://pytorch.org/docs/stable/nn.html#conv2d中的组?
例子:
import tensorflow as tf
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
np.random.seed(2020)
print('tf.__version__', tf.__version__)
def get_data_batch():
bs = 2
h = 3
w = 3
c = 4
x_np = np.random.rand(bs, h, w, c)
x_np = x_np.astype(np.float32)
print('x_np.shape', x_np.shape)
return x_np
def run_conv_dw():
print('='*60)
x_np = get_data_batch()
in_channels = x_np.shape[-1]
kernel_size = 3
channel_multiplier = 1
with tf.Session() as sess:
x_tf = tf.convert_to_tensor(x_np)
filter = tf.get_variable('w1', [kernel_size, kernel_size, in_channels, channel_multiplier],
initializer=tf.contrib.layers.xavier_initializer())
z_tf = tf.nn.depthwise_conv2d(x_tf, filter=filter, strides=[1, 1, 1, 1], padding='SAME')
sess.run(tf.global_variables_initializer())
z_np = sess.run(fetches=[z_tf], feed_dict={x_tf: x_np})[0]
print('z_np.shape', z_np.shape)
if '__main__' == __name__:
run_conv_dw()
通道乘数不能为浮点数:
如果channel_multiplier = 1
:
x_np.shape (2, 3, 3, 4)
z_np.shape (2, 3, 3, 4)
如果channel_multiplier = 2
:
x_np.shape (2, 3, 3, 4)
z_np.shape (2, 3, 3, 8)
解决方案
在 pytorch 方面:
- 每组总是一个输入通道,每组“channel_multiplier”输出通道;
- 不是一步到位;
- 见1
我看到了一种模拟每组多个输入通道的方法。对于两个, do depthwise_conv2d
,然后将结果张量作为一副纸牌分成两半,然后按元素求和获得的一半(在 relu 等之前)。请注意,输入通道号i
将被归为i+inputs/2
一组。
编辑:上面的技巧对小团体很有用,对于大团体来说,只需将输入张量拆分为 N 个部分,其中 N 是组数,conv2d
独立制作每个部分,然后连接结果。
推荐阅读
- powershell - Powershell:在开始时添加换行符
- flutter - 使用 png 或 jpeg 在 Flutter 中更改应用程序徽标?
- access-token - Google 操作构建器实现 webhook 未将用户访问令牌传递到后端
- ios - 构建设置下的特定目标缺少“架构”部分。在模拟器上工作正常,导致设备出现问题
- laravel - 如何设置重写规则以使用 apache 在同一服务器中转到前端和后端
- typescript - Typescript - 从 yaml 读取变量并用它们替换另一个文件中的标记
- typescript - 如何键入传入函数的窄数组类型?
- flutter - 推送扩展的 Widget Flutter
- python - 如何在 python 中检查 Outlook ActiveWindow 的 TypeName
- sql - 根据另一列中的值何时更改添加新列