tensorflow - 如何在 Keras 中使用“有状态”变量/张量创建自定义层?
问题描述
我想请你帮忙创建我的自定义层。我要做的实际上很简单:生成一个带有“有状态”变量的输出层,即每批更新其值的张量。
为了让一切更清楚,这里是我想做的一个片段:
def call(self, inputs)
c = self.constant
m = self.extra_constant
update = inputs*m + c
X_new = self.X_old + update
outputs = X_new
self.X_old = X_new
return outputs
这里的想法很简单:
X_old
在初始化为 0def__ init__(self, ...)
update
被计算为层的输入的函数- 计算层的输出(即
X_new
) - 的值
X_old
设置为等于,X_new
以便在下一批X_old
中不再等于零,而是等于X_new
上一批。
我发现可以K.update
完成这项工作,如示例所示:
X_new = K.update(self.X_old, self.X_old + update)
这里的问题是,如果我尝试将层的输出定义为:
outputs = X_new
return outputs
当我尝试 model.fit() 时,我会收到以下错误:
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have
gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
即使我强加了这个错误layer.trainable = False
并且我没有为图层定义任何偏差或权重,我仍然会遇到这个错误。另一方面,如果我只是这样做self.X_old = X_new
, 的值X_old
不会得到更新。
你们有解决方案来实现这个吗?我相信这不应该那么难,因为有状态的 RNN 也有“类似”的功能。
在此先感谢您的帮助!
解决方案
定义自定义层有时会变得令人困惑。您覆盖的某些方法将被调用一次,但它给您的印象是,就像许多其他 OO 库/框架一样,它们将被调用多次。
这就是我的意思:当您定义一个层并在模型中使用它时,您为覆盖call
方法编写的 Python 代码不会在向前或向后传递中直接调用。相反,它只在您调用时调用一次model.compile
。它将 python 代码编译为计算图,张量将在其中流动的图是训练和预测期间的计算。
这就是为什么如果你想通过添加一个print
语句来调试你的模型是行不通的;您需要使用tf.print
向图形添加打印命令。
您想要拥有的状态变量也是如此。而不是简单地分配old + update
给new
您,而是需要调用一个 Keras 函数,将该操作添加到图中。
请注意,张量是不可变的,因此您需要tf.Variable
在__init__
方法中定义状态。
所以我相信这段代码更像你正在寻找的:
class CustomLayer(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super(CustomLayer, self).__init__(**kwargs)
self.state = tf.Variable(tf.zeros((3,3), 'float32'))
self.constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
self.extra_constant = tf.constant([[1,1,1],[1,0,-1],[-1,0,1]], 'float32')
self.trainable = False
def call(self, X):
m = self.constant
c = self.extra_constant
outputs = self.state + tf.matmul(X, m) + c
tf.keras.backend.update(self.state, tf.reduce_sum(outputs, axis=0))
return outputs
推荐阅读
- c11 - MPSC环形缓冲区的DPDK实现
- c# - c# 中的 Process 类是非托管资源吗?
- python-3.x - 如何使用 ROS 2 在 python 节点中发布批量图像?
- python - Matplotlib 设置不同图形的相同框大小
- c - 为什么代码不检查函数中的条件
- reactjs - 在 React 中提交登录时,我在 (Localhost)URL 中获取我的密码和电子邮件作为查询字符串吗?
- reactjs - 在 Reactjs 中切换按钮的状态
- python - 在python中根据另一个子列表中的相应权重对子列表进行排序
- java - 无法解析 Cargo 跟踪器应用程序的依赖项和插件
- batch-file - 如果批处理中不存在变量