python - 尝试使用 tensorflow 自定义回调获取中间层预测时出现“层未连接,没有输入返回”错误
问题描述
我正在尝试使用自定义回调在训练期间访问模型中间层的预测。以下实际代码的精简版本演示了该问题。
import tensorflow as tf
import numpy as np
class Model(tf.keras.Model):
def __init__(self, input_shape=None, name="cus_model", **kwargs):
super(Model, self).__init__(name=name, **kwargs)
def build(self, input_shape):
self.dense1 = tf.keras.layers.Dense(input_shape=input_shape, units=32)
def call(self, input_tensor):
return self.dense1(input_tensor)
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
get_output = tf.keras.backend.function(
inputs = self.model.layers[0].input,
outputs = self.model.layers[0].output
)
print("Layer output: ",get_output.outputs)
X = np.ones((8,16))
y = np.sum(X, axis=1)
model = Model()
model.compile(optimizer='adam',loss='mean_squared_error', metrics='accuracy')
model.fit(X,y, epochs=8, callbacks=[CustomCallback()])
回调是按照此答案中的建议编写的。收到以下错误:
<ipython-input-3-635fd53dbffc> in on_epoch_end(self, epoch, logs)
12 def on_epoch_end(self, epoch, logs=None):
13 get_output = tf.keras.backend.function(
---> 14 inputs = self.model.layers[0].input,
15 outputs = self.model.layers[0].output
16 )
.
.
AttributeError: Layer dense is not connected, no input to return.
这是什么原因造成的?如何解决?
解决方案
我运行这个没有问题:
import tensorflow as tf
import numpy as np
X = np.ones((8,16))
y = np.sum(X, axis=1)
class CustomCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
get_output = tf.keras.backend.function(
inputs = self.model.layers[0].input,
outputs = self.model.layers[1].output # return output of first dense
)
print("\nLayer output: ", get_output(X))
inp = tf.keras.layers.Input((16,))
dense1 = tf.keras.layers.Dense(units=32)(inp)
dense2 = tf.keras.layers.Dense(units=20)(dense1)
model = tf.keras.Model(inp, dense2)
model.compile(optimizer='adam',loss='mean_squared_error', metrics='accuracy')
model.fit(X,y, epochs=8, callbacks=[CustomCallback()])
推荐阅读
- mongodb - 如何解决“MongoError: $where is not allowed in this atlas tier”?
- android - 如何从 HEIF 格式获取旋转数据?
- python - setuptools的“编程语言”分类器中版本的目的是什么?
- saucelabs - 适用于 Android/IOS 酱实验室的 Webdriver.io
- mongodb - 创建一个可以访问特定数据库上的 listCollections 的 mongo 用户
- webpack - 自定义 Webpack 插件:访问转换代码的钩子
- php - 避免在 SSH 命令中区分大小写
- rest - 如何在加特林中使用单个用户对多个请求进行负载测试
- ios - 关于在 Swift 中将 Int 类型的值分配给 String 类型
- keras - Keras 通过向量查找索引 - 反向嵌入层