tensorflow - 为什么keras使用“call”而不是__call__?
问题描述
我喜欢( https://www.tensorflow.org/tutorials/eager/custom_layers)中的以下代码
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[int(input_shape[-1]),
self.num_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
最后两行是调用方法,而不像通常 的带有两个下划线的python类方法调用。它们之间有什么区别吗?
解决方案
以下答案基于https://tf.wiki/zh/basic/models.html。
ClassA
基本上在 Python 中,当您使用调用类中的实例时ClassA()
,它等价于ClassA.__call__()
. 所以在这种情况下使用__call__()
而不是似乎是合理的,对吧?call()
但是,我们使用的原因call()
是,当tf.keras
调用模型或层时,它有自己的内部操作,这对于保持其内部结构至关重要。结果,它公开了一种call()
客户重载的方法。__call()__
调用call()
以及一些内部操作,所以当我们重新加载call()
继承自tf.keras.Model
ortf.keras.Layer
时,我们可以在保持tf.keras
内部结构的同时调用我们的自定义代码。
例如,根据我的经验,如果您的输入是一个 numpy 数组而不是张量,如果您在其中编写客户代码,则无需手动转换它,call()
但如果您覆盖__call__()
,这将是一个问题,因为某些内部操作不是叫。
推荐阅读
- android - setAlarmClock() 反复触发警报 [问题原来是由于 Spinner 造成的]
- laravel - SQLSTATE [23000]:违反完整性约束:1048 列“分数”不能为空
- python - 尽管已将其全球化,但在赋值之前引用局部变量时出错
- javascript - 如何在关联数组中合并重复项?
- sql - 根据另一个表中的值在一行中生成一个随机值
- c++ - C++ 中的内存模型和单例
- mysql - 如何编写 SQL 查询来计算三个连续值的平均值?
- java - 将密钥发送到您无法检查的文本框(使用 Java)
- django - 如何在 Docker 容器上运行 Django
- rstudio - 使用面板数据进行固定效应回归 - 输出中不包括虚拟变量