python - keras 上的 Tensorboard 回调在训练多个网络时会给出 InvalidArgumentError
问题描述
我有一种方法train_model
可以获取 keras 模型对象作为输入并对其进行训练。我的代码中的其他地方有一个循环,它在每次迭代中创建一个新模型并将其传递给这个方法。如果我没有通过 TensorBoard 回调,我的代码可以正常工作。但是,当我通过 TensorBoard 回调时,第一个网络得到了训练,但是对于第二个网络,我得到了这个错误:
tensorflow.python.framework.errors_impl.InvalidArgumentError:您必须为占位符张量“dense_7_target”提供一个值,其 dtype 为 float 和 shape [?,?] [[{{node dense_7_target}} = Placeholderdtype=DT_FLOAT, shape=[?,? ], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
调用.fit
方法后。(我正在构建的网络是 5 层)
还有更多意想不到的行为:当我第二次运行此代码时,第一个网络不需要训练(因为我已经保存了模型并且它只会加载它)并且第二个网络得到了没有错误的训练,但是对于第三个网络,我得到了同样的错误。
在这种情况下,当我检查 TensorBoard 图表时,我看到第一个网络已正确创建,但第二个网络的层数是应有的两倍(好像第一个模型已先加载,然后第二个网络已构建在上面)。这是我的train_model
方法:
def train_model(model, data, dataname, MODEL_DIR, LOG_DIR, BS, EP, callbacks):
X_train, Y_train, X_test, Y_test = data
if not os.path.exists(MODEL_DIR):
os.makedirs(MODEL_DIR)
model_callback = tf.keras.callbacks.ModelCheckpoint(MODEL_DIR + dataname + '/',
monitor='mse',
verbose=1,
save_weights_only=True)
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR + dataname,histogram_freq=10)
callbacks += [model_callback, tb_callback]
if not os.path.exists(MODEL_DIR + dataname + '/'):
model.fit(X_train, Y_train, batch_size=BS, epochs=EP, verbose=0, callbacks=callbacks, validation_split=0.2)
validation_split=0.2)
else:
model.load_weights(MODEL_DIR + dataname + '/')
return model
我已经尽我所能,我真的不知道我的代码出了什么问题。
任何帮助表示赞赏。提前致谢。
解决方案
推荐阅读
- c - 我的 C 代码没有捕捉到“最终客户”结束
- c# - 当 where 子句包含双引号特殊字符时,Linq to SQL 失败
- android - 在 Android 中使用 UPI APP 向用户发送付款
- python - Python - 在 for 循环中附加到二维数组会导致以前的条目被覆盖
- android - 未报告 Firebase 致命崩溃
- javascript - 将选定的值 Id 从父级传递给子级 Kendo DropdownList
- c++ - to_string(int) 的意外输出
- python - 在 python 中使用 selenium 查找单选按钮
- sql - SQL 语法异常。在休眠中映射和保存列的问题
- google-bigquery - 如何使用 BigQuery 过滤多个值