python - Keras 无法在回调中计算图形节点
问题描述
为了说明这个问题,请考虑以下模型的(完全人为的)示例:
import numpy as np
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import Callback
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
class RandomSeq(Sequence):
def __len__(self):
return 5
def __getitem__(self, idx):
return (np.array([np.arange(39).reshape((39,)) for i in range(100)]),
np.array([-np.arange(39).reshape((39,)) for i in range(100)]))
class Foo(Callback):
def __init__(self, d):
super(Foo, self).__init__()
self._d = d
def on_epoch_end(self, epoch, logs=None):
print(epoch)
print(K.eval(self._d))
x = Input(shape=(39,), dtype='float32', name='input')
y_pred = Dense(39)(x)
y_another = x * 2
m = Model(inputs=x, outputs=y_pred)
m.compile(optimizer='sgd', loss='mse')
seq = RandomSeq()
m.fit_generator(seq, epochs=5, callbacks=[Foo(y_another)])
RandomSeq
只是一个返回x
和y
批处理的序列。Foo
是一个回调,它将尝试d
在一个时期结束时评估附加的数量。对我来说,如果我选择d
是y_pred
or y_another
,那么 Keras 会抱怨占位符x
( input
) 没有被输入。
您必须使用 dtype float 和 shape [?,39] 为占位符张量“input_1”提供一个值
这是预期的行为吗?如果是这样,是否有另一种方法来计算 Keras 回调中的节点?请注意,如果没有计算上述图形节点的回调,则该示例可以正常工作。
解决方案
这不是正确的做法。运行 K.eval(y_another) 您要求 keras 后端评估一个 Input 对象(它只是您想要输入网络的数据的占位符)而不用任何数据输入它,这就是您的错误的原因。
因此,假设您要计算网络的输出,给定一个新输入,该输入是一个随机序列乘以 2(对吗?),那么您应该修改on_epoch_end(self, epoch, logs=None)
回调中的主体,如下所示:
x, _ = next(RandomSeq())
print(self.model.predict(x*2))
推荐阅读
- spring - spring security - jwt 请求过滤器验证和验证会话,但主体为空
- java - 如何使用 Gradle 设置 JPMS 模块的 ModuleMainClass 属性?
- bash - 在第一次出现字符之前重命名文件名
- android - 为导航充气时由 android.view.InflateException 引起的运行时异常
- python - 如何在 Python 中不使用 math.log 计算对数
- docker - Docker 拒绝在 443 上卷曲
- c - 如何在C中连接vscode中的源文件和头文件
- php - 对数字数组进行排序的函数
- sql - SQL - 如果月份中的小时存在,则为“1”,如果不存在,则为“0”
- java - 将服务帐户与 Dataflow 一起使用 - 获取 storage.objects.get 访问错误