首页 > 解决方案 > 仅在 Keras 中训练多输出模型时显示总损失

问题描述

我正在通过 Keras 的功能 API 实现一个自动编码器模型。我的模型是多输出的,结果是在每个输出上评估一个损失函数。在训练期间,这些损失的加权和被最小化:

losses = [jsd for j in range(m)]  # JSD loss function for each output
autoencoder = Model(inputs, decodes)
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
autoencoder.compile(optimizer=sgd, loss=losses, loss_weights=[1 for k in range(m)]) # each output has the same priority

然后我将我的模型拟合到训练数据并在测试数据上对其进行评估:

history = autoencoder.fit(train_corr, train_attr_corr, epochs=50, batch_size=10, shuffle=True, verbose=2,
                          validation_data=(test_corr, test_attr_GT))

因为verbose=2,训练和验证损失会在每个 epoch 结束时显示在控制台中。但是,因为模型是多输出的,所以会显示所有的“子损失”。例如:

Epoch 1/50
 - 3s - loss: 0.3356 - dense_4_loss: 0.0647 - dense_5_loss: 0.0436 - dense_6_loss: 0.0391 - dense_7_loss: 0.0378 - dense_8_loss: 0.0250 - dense_9_loss: 0.0362 - val_loss: 0.1067 - val_dense_4_loss: 0.0101 - val_dense_5_loss: 0.0042 - val_dense_6_loss: 0.0031 - val_dense_7_loss: 0.0036 - val_dense_8_loss: 0.0041 - val_dense_9_loss: 0.0066

问题:是否可以只显示每个 epoch的总训练损失 ( loss) 和总验证损失?val_loss

编辑:在上面的例子中,我只想显示loss: 0.3356and val_loss: 0.1067

标签: pythontensorflowmachine-learningkerasneural-network

解决方案


在 Keras model.fit 函数中使用默认的详细程度选项是不可能的。但是,您可以使用自定义回调来实现此目的。使用 禁用拟合函数中的详细程度verbosity=0。定义以下回调函数,该函数在纪元开始和结束时使用修改后的结果覆盖默认回调。

class PrinterCallback(tf.keras.callbacks.Callback):

    # def on_train_batch_begin(self, batch, logs=None):
    #     # Do something on begin of training batch

    def on_epoch_end(self, epoch, logs=None):
        print('EPOCH: {}, Train Loss: {}, Val Loss: {}'.format(epoch,
                                                               logs['loss'],
                                                               logs['val_loss']))

    def on_epoch_begin(self, epoch, logs=None):
        print('-'*50)
        print('STARTING EPOCH: {}'.format(epoch))

    # def on_train_batch_end(self, batch, logs=None):
    #     # Do something on end of training batch
    #

在调用 model.fit 时,将此回调用作callback=[PrinterCallback()]. 还有其他功能也可以在这里操作。例如,您可以在 train begin 上做什么等(代码中显示了几个示例)。随意修改所需值的打印方式,例如,控制小数位。

有关 Keras 回调的详细信息可在此处获得,您还可以查看其他回调的源代码以实现您自己的回调。

希望有帮助!


推荐阅读