首页 > 解决方案 > keras 上的 Tensorboard 回调在训练多个网络时会给出 InvalidArgumentError

问题描述

我有一种方法train_model可以获取 keras 模型对象作为输入并对其进行训练。我的代码中的其他地方有一个循环,它在每次迭代中创建一个新模型并将其传递给这个方法。如果我没有通过 TensorBoard 回调,我的代码可以正常工作。但是,当我通过 TensorBoard 回调时,第一个网络得到了训练,但是对于第二个网络,我得到了这个错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:您必须为占位符张量“dense_7_target”提供一个值,其 dtype 为 float 和 shape [?,?] [[{{node dense_7_target}} = Placeholderdtype=DT_FLOAT, shape=[?,? ], _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

调用.fit方法后。(我正在构建的网络是 5 层)

还有更多意想不到的行为:当我第二次运行此代码时,第一个网络不需要训练(因为我已经保存了模型并且它只会加载它)并且第二个网络得到了没有错误的训练,但是对于第三个网络,我得到了同样的错误。

在这种情况下,当我检查 TensorBoard 图表时,我看到第一个网络已正确创建,但第二个网络的层数是应有的两倍(好像第一个模型已先加载,然后第二个网络已构建在上面)。这是我的train_model方法:

def train_model(model, data, dataname, MODEL_DIR, LOG_DIR, BS, EP, callbacks):
X_train, Y_train, X_test, Y_test = data
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)
model_callback = tf.keras.callbacks.ModelCheckpoint(MODEL_DIR + dataname + '/',
                                                    monitor='mse',
                                                    verbose=1,
                                                    save_weights_only=True)
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR + dataname,histogram_freq=10)
callbacks += [model_callback, tb_callback]
if not os.path.exists(MODEL_DIR + dataname + '/'):
    model.fit(X_train, Y_train, batch_size=BS, epochs=EP, verbose=0, callbacks=callbacks, validation_split=0.2)
validation_split=0.2)
    else:
        model.load_weights(MODEL_DIR + dataname + '/')
    return model

我已经尽我所能,我真的不知道我的代码出了什么问题。

任何帮助表示赞赏。提前致谢。

标签: pythontensorflowmachine-learningkerastensorboard

解决方案


推荐阅读