首页 > 解决方案 > 如何在 TF2.0 的 `tf.Module` 中初始化变量

问题描述

在Tensorflow2.0中,我发现我可以通过以下方式初始化模型中的变量

class MyModel(tf.keras.Model):
    def __init__(self, *args, kwargs**):
        """ some definition here """
        self(tf.keras.Input(shape=(3,)))

    def call(self, x):
        """ some implementation """        

但我不能做类似的事情

class MyModel(tf.keras.Model):
    def __init__(self, *args, kwargs**):
        """ some definition here """
        self.step(tf.keras.Input(shape=(3,)))

    def step(self, x):
        """ some implementation """        

这将给出错误 在此处输入图像描述 我想做第二个的原因是我尝试继承MyModeltf.Module,它没有__call__可用的---即使我定义了一个,也会出现同样的错误。我想知道是否有一种方法可以tf.Module像在第一个代码块中那样初始化继承自的类中的变量?

标签: tensorflowtensorflow2.0

解决方案


遗憾的是,Keras 功能/符号 API 仅与 Keras 兼容(例如 compile+fit)。

您可能可以tf.zeros改用(例如self(tf.zeros(input_shape))),尽管这可能会产生不良的副作用(例如影响您的批量标准统计信息)。

如果您想要一个强大的解决方案,您可能需要考虑使用snt.build(self, input_shape)[0] ,它是Sonnet 2中的一个实用函数(一个包含一堆 commontf.Module的库)。

[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50


推荐阅读