首页 > 解决方案 > 为什么keras使用“call”而不是__call__?

问题描述

我喜欢( https://www.tensorflow.org/tutorials/eager/custom_layers)中的以下代码

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_variable("kernel",
                                shape=[int(input_shape[-1]),
                                       self.num_outputs])

  def call(self, input):
    return tf.matmul(input, self.kernel)

最后两行是调用方法,而不像通常 的带有两个下划线的python类方法调用。它们之间有什么区别吗?

标签: tensorflowkeras

解决方案


以下答案基于https://tf.wiki/zh/basic/models.html

ClassA基本上在 Python 中,当您使用调用类中的实例时ClassA(),它等价于ClassA.__call__(). 所以在这种情况下使用__call__()而不是似乎是合理的,对吧?call()

但是,我们使用的原因call()是,当tf.keras调用模型或层时,它有自己的内部操作,这对于保持其内部结构至关重要。结果,它公开了一种call()客户重载的方法。__call()__调用call()以及一些内部操作,所以当我们重新加载call()继承自tf.keras.Modelortf.keras.Layer时,我们可以在保持tf.keras内部结构的同时调用我们的自定义代码。

例如,根据我的经验,如果您的输入是一个 numpy 数组而不是张量,如果您在其中编写客户代码,则无需手动转换它,call()但如果您覆盖__call__(),这将是一个问题,因为某些内部操作不是叫。


推荐阅读