tensorflow - 如何扩展 GlobalAveragePooling2D() 的输出以适合 BiSeNet?
问题描述
我正在尝试构建“ https://github.com/Blaizzy/BiSeNet-Implementation ”图中所示的 BiseNet。
当我想使用 Keras(tf-backend) 中的 GlobalAveragePooling2D() 完成图(b) 中的 Attention Refined Module 时,发现 GlobalAveragePooling2D() 的输出形状不适合下一次卷积。
我在 github 中检查了 BiSeNet 代码的许多实现,但是,它们中的大多数都使用 AveragePooling2D(size=(1,1)) 代替。但是AveragePooling2D(size=(1,1)) 完全没有意义。
所以我定义了一个 lambda 层来做我想做的事情(所选代码如下所示)。lambda 层有效,但看起来很丑:
def samesize_globalAveragePooling2D(inputtensor):
# inputtensor shape:(?, 28,28,32)
x = GlobalAveragePooling2D()(inputtensor) # x shape:(?, 32)
divide = tf.divide(inputtensor, inputtensor) # divide shape:(?, 28,28,32)
x2 = x * divide # x2 shape:(?, 28,28,32)
global_pool = Lambda(function=samesize_globalAveragePooling2D)(conv_0)
希望得到建议,使这个 lambda 更加优雅。
谢谢!
解决方案
这可以使用 tf.reduce_mean 上的 lambda 层来完成。
tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=[1, 2], keep_dims=True))
推荐阅读
- html - 我的项目没有与 flex 对齐/没有移动响应
- ios - iOS 14 通过 Facebook SDK 获得用户同意
- python - Python requests.get() 循环不返回任何内容
- c# - 我应该在 FireTv 应用程序中为视频广告使用什么 Amazon SDK?
- javascript - 正确导入节点模块
- mysql - SQL:如何在嵌套查询中搜索两列?
- ssis - 如何从本地执行 SSIS 包,但 (.csv, .xls ...) 等文件在 VM 上
- android - Ionic Cordova Android:输入调度超时错误
- shell - 如何使用新的 buildpacks.io 框架部署 heroku nginx buildpack?
- visual-studio-code - 找到了一些VSCode:如何确定光标在编辑器中的列位置?