python - 深度上的 TensorFlow 平均值
问题描述
我正在使用向量序列作为张量流中 NN 的输入数据,我想在输入的深度上执行平均池化。
我尝试使用以下 lambda 层:
depth_pool = keras.layers.Lambda(
lambda X: tf.nn.avg_pool1d(X,
ksize=(1, 1, 3),
strides=(1, 1, 3),
padding="VALID"))
但是,我收到错误消息:
UnimplementedError:尚不支持非空间池化。
有没有办法达到预期的结果?
非常感谢您的帮助
解决方案
如果您的输入具有这些维度:(None, timestamps, features)
您可以简单地将深度与其他维度置换,应用标准池化,然后置换回原始维度。
举个例子......如果您的网络接受输入白色形状(None, 20, 99)
,您可以简单地执行以下操作来获得深度池:
inp = Input((20,99))
depth_pool = Permute((2,1))(inp)
depth_pool = AveragePooling1D(3)(depth_pool)
depth_pool = Permute((2,1))(depth_pool)
m = Model(inp, depth_pool)
m.summary()
摘要:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_4 (InputLayer) [(None, 20, 99)] 0
_________________________________________________________________
permute_4 (Permute) (None, 99, 20) 0
_________________________________________________________________
average_pooling1d_3 (Average (None, 33, 20) 0
_________________________________________________________________
permute_5 (Permute) (None, 20, 33) 0
=================================================================
输出有形状(None, 20, 33)
如果您的输入具有这些维度:(None, features, timestamps)
您可以简单地data_format='channels_first'
在您的图层中设置
inp = Input((20,99))
depth_pool = AveragePooling1D(3, data_format='channels_first')(inp)
m = Model(inp, depth_pool)
m.summary()
摘要:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_9 (InputLayer) [(None, 20, 99)] 0
_________________________________________________________________
average_pooling1d_7 (Average (None, 20, 33) 0
=================================================================
输出有形状(None, 20, 33)
推荐阅读
- python - Flask-WTForm FileField 返回 None 而不是上传的文件
- external - 找出外部命令的可运行性
- highcharts - 如何在 Highcharts v.4 中引用系列?
- c# - 在动态的一个操作中创建相关实体时,有没有办法将联系人“parentcustomerid”链接到帐户?
- java - Java SOAP WS 客户端从命令行失败,在 Eclipse 中工作(令人震惊,对吗?)
- git - 重写历史记录后,仍然可以通过来自其他 repos 的引用看到旧的 git 提交
- javascript - 来自前端的 Javascript 时区
- flutter - Dart Riverpod:未定义的类'WidgetRef'
- r - googledrive::drive_upload() 中的路径参数不起作用
- github - 如何使用 GitHub Actions 下载存档的 GitHub 发布源代码