首页 > 解决方案 > 使用 tensorflow 签名和 tf.function 更新外部状态

问题描述

我正在创建一个图层,它是批量标准化的变体,并使用tf.function装饰器来加速它。但是,我收到一条错误消息,提示 autograph 不知道我是否要重用我正在尝试更新的变量。

class MyClass(tf.keras.layers.Layer):

    def build():
        self.foo = self.add_weight(...)

    @tf.function
    def call(inputs, training=None):
        lst = [K.moving_average_update(self.foo, .5, .999)]
        self.add_updates(lst)

这会引发类似于以下内容的错误:

ValueError: Variable my_class/foo/biased already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?

告诉签名在函数外重用状态的惯用方法是什么?

标签: pythontensorflowtensorflow2.0

解决方案


为了重用self.foo您可以在 tf.variable_scope 块中定义的变量,设置参数reuse = tf.AUTO_REUSE. 例如:

def build(self, input_shape):

    with tf.variable_scope('scope', reuse=tf.AUTO_REUSE):
        self.foo = tf.Variable(initial_value=YOUR_INITIAL_VALUE, name='var')

    trainable = None  # True or False
    if trainable:
        self._trainable_weights.append(self.foo)
    else:
        self._non_trainable_weights.append(self.foo)

注意:未经测试。


推荐阅读