python - 加权平均:自定义层权重在 TensorFlow 2.2.0 中没有变化
问题描述
我正在尝试在 TensorFlow 中实现两个张量之间的加权平均值,其中可以自动学习权重。按照这里关于如何为 keras 模型设计自定义层的建议,我的尝试如下:
class WeightedAverage(tf.keras.layers.Layer):
def __init__(self):
super(WeightedAverage, self).__init__()
init_value = tf.keras.initializers.Constant(value=0.5)
self.w = self.add_weight(name="weight",
initializer=init_value,
trainable=True)
def call(self, inputs):
return tf.keras.layers.average([inputs[0] * self.w,
inputs[1] * (1 - self.w)])
现在的问题是,在训练模型、保存并再次加载之后, 的值w
仍然是 0.5。参数是否有可能没有收到任何梯度更新?在打印我的模型的可训练变量时,会列出该参数,因此在调用model.fit
.
解决方案
这是在两个张量之间实现加权平均的可能性,其中可以自动学习权重。我还介绍了权重总和必须为 1 的约束。为了实现这一点,我们必须简单地对权重应用 softmax。在下面的虚拟示例中,我将此方法与两个完全连接的分支的输出相结合,但您可以在其他所有场景中对其进行管理
这里是自定义层:
class WeightedAverage(Layer):
def __init__(self):
super(WeightedAverage, self).__init__()
def build(self, input_shape):
self.W = self.add_weight(
shape=(1,1,len(input_shape)),
initializer='uniform',
dtype=tf.float32,
trainable=True)
def call(self, inputs):
# inputs is a list of tensor of shape [(n_batch, n_feat), ..., (n_batch, n_feat)]
# expand last dim of each input passed [(n_batch, n_feat, 1), ..., (n_batch, n_feat, 1)]
inputs = [tf.expand_dims(i, -1) for i in inputs]
inputs = Concatenate(axis=-1)(inputs) # (n_batch, n_feat, n_inputs)
weights = tf.nn.softmax(self.W, axis=-1) # (1,1,n_inputs)
# weights sum up to one on last dim
return tf.reduce_sum(weights*inputs, axis=-1) # (n_batch, n_feat)
这是回归问题中的完整示例:
inp1 = Input((100,))
inp2 = Input((100,))
x1 = Dense(32, activation='relu')(inp1)
x2 = Dense(32, activation='relu')(inp2)
W_Avg = WeightedAverage()([x1,x2])
out = Dense(1)(W_Avg)
m = Model([inp1,inp2], out)
m.compile('adam','mse')
n_sample = 1000
X1 = np.random.uniform(0,1, (n_sample,100))
X2 = np.random.uniform(0,1, (n_sample,100))
y = np.random.uniform(0,1, (n_sample,1))
m.fit([X1,X2], y, epochs=10)
最后,您还可以通过这种方式可视化权重的值:
tf.nn.softmax(m.get_weights()[-3]).numpy()
推荐阅读
- php - php中私有方法的继承和后期静态绑定
- dart - AngularDart 触摸和移动
- java - 无法在 android web 视图中从服务器下载文件
- matlab - 在MATLAB中为神经网络执行用户定义的卷积函数非常慢
- python - NLTK语言树遍历和提取名词短语(NP)
- xcode - 如何在 Xcode 中渲染具有半透明纹理的 3D 模型?
- c++ - 为什么我看不到字符串?
- docker - 错误:在 Jenkins 中获取远程 repo 'origin'
- overriding - Prestashop 1.7 覆盖 CmsController
- python-3.x - 无法在嵌套表中的 Python Selenium 中使用 CSS 选择器引用元素