python - tensorflow中Conv1d的情况下如何实现平均池化?
问题描述
我想在 conv1d 中实现平均池化。但tf.nn.avg_pool
功能只能在 4 维张量上实现。那么我应该怎么做才能克服这个问题呢?
def avg_pool(conv_out):
return tf.nn.avg_pool(conv_out,ksize=[1,1,2,1],strides=[1,1,2,1],padding='SAME')
i = tf.constant([1, 0, 2, 3, 0, 1], dtype=tf.float32)
data = tf.reshape(i, [1, int(i.shape[0]), 1], name='data')
kernel = tf.Variable(tf.random_normal([2,1,1]))
conv_out = tf.nn.conv1d(data, kernel, 2, 'VALID')
pool_out = avg_pool(conv_out)
解决方案
一种选择是为您的数据添加一个额外的维度,然后将其删除:
def avg_pool(conv_out):
conv_out_2d = conv_out[:, tf.newaxis]
pool_out_2d = tf.nn.avg_pool(conv_out_2d,
ksize=[1, 1, 2, 1],
strides=[1, 1, 2, 1],
padding='SAME')
pool_out = pool_out_2d[:, 0]
return pool_out
另一种可能性是使用泛型tf.nn.pool
:
def avg_pool(conv_out):
return tf.nn.pool(conv_out, window_shape=[2], pooling_type='AVG', padding='SAME')
请注意,在这种情况下,我不包括步幅,因为默认值与您在示例中使用的值匹配,但您也可以根据需要对其进行修改。
推荐阅读
- javascript - 尝试通过网站上的模态制作“联系细节”按钮,但每个人都获得相同的“模态内容”
- c# - C# 设置更积极的垃圾回收
- vba - VBA 提取数据:在没有 Next 的情况下出现编译错误
- android - 将 ImageView 权重设置为屏幕中心的 50%
- c# - 计算布尔属性为真的实例数 - C#
- node.js - 或者条件在 nodejs bot 中不起作用
- python - 用 Python 编写文件的一部分
- javascript - 孩子没有触发父母功能
- python - 在 python 中插入 MySQL 数据库的问题
- java - Mock链调用涉及Java中的Stream