首页 > 解决方案 > Tensorflow 覆盖自定义层中的范围

问题描述

我正在尝试在 tensorflow 中实现一个嘈杂的线性层,继承自 tf.keras.layers.Layer 。除了重用变量之外,一切都很好。这似乎源于范围界定的一些问题:每当我使用超类中的 add_weight 函数并且已经存在具有相同名称的权重时,它似乎会忽略范围中给定的重用标志并创建一个新变量。有趣的是,在类似的情况下,它并没有像往常一样在变量名末尾添加 1,而是在作用域名称中添加 1。

import tensorflow as tf

class NoisyDense(tf.keras.layers.Layer):
    def __init__(self,output_dim):
        self.output_dim=output_dim

        super(NoisyDense, self).__init__()

    def build(self, input_shape):
        self.input_dim = input_shape.as_list()[1]
        self.noisy_kernel = self.add_weight(name='noisy_kernel',shape=  (self.input_dim,self.output_dim))

def noisydense(inputs, units):

    layer = NoisyDense(units)

    return layer.apply(inputs)

inputs = tf.placeholder(tf.float32, shape=(1, 10),name="inputs")



scope="scope"
with tf.variable_scope(scope):
    inputs3 = noisydense(inputs,
           1)
    my_variable = tf.get_variable("my_variable", [1, 2, 3],trainable=True)


with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    inputs2 = noisydense(inputs,
           1)
    my_variable = tf.get_variable("my_variable", [1, 2, 3],trainable=True)

tvars = tf.trainable_variables()



init=tf.global_variables_initializer()    
with tf.Session() as sess:
    sess.run(init)
    tvars_vals = sess.run(tvars)

for var, val in zip(tvars, tvars_vals):
    print(var.name, val)

这导致变量

  scope/noisy_dense/noisy_kernel:0
   scope_1/noisy_dense/noisy_kernel:0
   scope/my_variable:0

正在打印。我希望它重用嘈杂的内核,而不是创建第二个内核,就像对 my_variable 所做的那样。

标签: tensorflowscope

解决方案


推荐阅读