tensorflow - 如何在 TF2.0 的 `tf.Module` 中初始化变量
问题描述
在Tensorflow2.0中,我发现我可以通过以下方式初始化模型中的变量
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self(tf.keras.Input(shape=(3,)))
def call(self, x):
""" some implementation """
但我不能做类似的事情
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self.step(tf.keras.Input(shape=(3,)))
def step(self, x):
""" some implementation """
这将给出错误
我想做第二个的原因是我尝试继承MyModel
自tf.Module
,它没有__call__
可用的---即使我定义了一个,也会出现同样的错误。我想知道是否有一种方法可以tf.Module
像在第一个代码块中那样初始化继承自的类中的变量?
解决方案
遗憾的是,Keras 功能/符号 API 仅与 Keras 兼容(例如 compile+fit)。
您可能可以tf.zeros
改用(例如self(tf.zeros(input_shape))
),尽管这可能会产生不良的副作用(例如影响您的批量标准统计信息)。
如果您想要一个强大的解决方案,您可能需要考虑使用snt.build(self, input_shape)
[0] ,它是Sonnet 2中的一个实用函数(一个包含一堆 commontf.Module
的库)。
[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50
推荐阅读
- holoviews - 使用 Dimension value_format 回调时,holoviews 找不到 flexx
- django - 使用 celery 任务将文件上传到 s3
- python - 用两个分隔符拆分 python 列表
- php - 具有已定义类型的 php 7 函数返回与类型相关的错误
- android - Flutter - 无法将文件复制到新路径
- apache-spark - Pyspark DataFrame - 将具有 n 个元素数组“key=value”的列转换为 n 个新列
- c# - 高效的半字节数据搜索查询算法?
- android - Android - 如何通过用户电子邮件进行提醒通知?
- url - apache2从域重定向到子文件夹
- javascript - 如何遍历 JSON 响应文本并使用 PHP 获取所有值