python - 仅在 Keras 中训练多输出模型时显示总损失
问题描述
我正在通过 Keras 的功能 API 实现一个自动编码器模型。我的模型是多输出的,结果是在每个输出上评估一个损失函数。在训练期间,这些损失的加权和被最小化:
losses = [jsd for j in range(m)] # JSD loss function for each output
autoencoder = Model(inputs, decodes)
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
autoencoder.compile(optimizer=sgd, loss=losses, loss_weights=[1 for k in range(m)]) # each output has the same priority
然后我将我的模型拟合到训练数据并在测试数据上对其进行评估:
history = autoencoder.fit(train_corr, train_attr_corr, epochs=50, batch_size=10, shuffle=True, verbose=2,
validation_data=(test_corr, test_attr_GT))
因为verbose=2
,训练和验证损失会在每个 epoch 结束时显示在控制台中。但是,因为模型是多输出的,所以会显示所有的“子损失”。例如:
Epoch 1/50
- 3s - loss: 0.3356 - dense_4_loss: 0.0647 - dense_5_loss: 0.0436 - dense_6_loss: 0.0391 - dense_7_loss: 0.0378 - dense_8_loss: 0.0250 - dense_9_loss: 0.0362 - val_loss: 0.1067 - val_dense_4_loss: 0.0101 - val_dense_5_loss: 0.0042 - val_dense_6_loss: 0.0031 - val_dense_7_loss: 0.0036 - val_dense_8_loss: 0.0041 - val_dense_9_loss: 0.0066
问题:是否可以只显示每个 epoch的总训练损失 ( loss
) 和总验证损失?val_loss
编辑:在上面的例子中,我只想显示loss: 0.3356
and val_loss: 0.1067
。
解决方案
在 Keras model.fit 函数中使用默认的详细程度选项是不可能的。但是,您可以使用自定义回调来实现此目的。使用 禁用拟合函数中的详细程度verbosity=0
。定义以下回调函数,该函数在纪元开始和结束时使用修改后的结果覆盖默认回调。
class PrinterCallback(tf.keras.callbacks.Callback):
# def on_train_batch_begin(self, batch, logs=None):
# # Do something on begin of training batch
def on_epoch_end(self, epoch, logs=None):
print('EPOCH: {}, Train Loss: {}, Val Loss: {}'.format(epoch,
logs['loss'],
logs['val_loss']))
def on_epoch_begin(self, epoch, logs=None):
print('-'*50)
print('STARTING EPOCH: {}'.format(epoch))
# def on_train_batch_end(self, batch, logs=None):
# # Do something on end of training batch
#
在调用 model.fit 时,将此回调用作callback=[PrinterCallback()]
. 还有其他功能也可以在这里操作。例如,您可以在 train begin 上做什么等(代码中显示了几个示例)。随意修改所需值的打印方式,例如,控制小数位。
有关 Keras 回调的详细信息可在此处获得,您还可以查看其他回调的源代码以实现您自己的回调。
希望有帮助!
推荐阅读
- nuget - 使用四部分的版本号(ww.xx.yy.zz)有什么缺点吗?
- java - 将 SoapUI 属性值设置为今天 + 1 年
- kubernetes - Autoscaler 未扩展,使节点处于 NotReady 状态,Pod 处于未知状态
- jquery - 如何在其容器上扩展按钮功能
- github - GitHub 克隆到桌面使用 TortoiseGit 而不是 GitHub Desktop
- python - 用 ausent 值填充动态列表查询集并将其合并
- ios - Windows VS - Xamarin IOS info.plist 保存后设备方向更改
- java - 使用 Spring Boot 的 Oracle 到 SQL Server 数据传输
- python - 有没有办法进入 for 循环,然后从当前位置反转该循环?
- node.js - cookie 没有从 jquery ajax 中保存