首页 > 解决方案 > 如何为 LocallyConnected2D 层使用 WeightNormalization 包装器

问题描述

我正在尝试在这样的层tfa.layers.WeightNormalization周围使用包装器:tf.layers.LocallyConnected2D

from tensorflow_addons.layers import WeightNormalization
import tensorflow as tf

x = tf.ones((1, 32, 32, 3))
x = WeightNormalization(tf.keras.layers.LocallyConnected2D(3, 3))(x)

它给出了以下错误:

TypeError: 'NoneType' object is not callable

作为记录,这确实适用于Conv2D图层。知道如何让这个与LocallyConnected2D图层一起工作吗?

标签: pythontensorflowkerasnormalizationkeras-layer

解决方案


来自评论

from tensorflow_addons.layers import WeightNormalization
import tensorflow as tf

x = tf.ones((1, 32, 32, 3))
x = WeightNormalization(tf.keras.layers.LocallyConnected2D(3, 3), data_init=False)(x)

(转述自 Marco Cerliani)


推荐阅读