python - 具有两个输出的 Keras MSE 损失
问题描述
我有一个模型,其输出层是Dense(2)
所以我的输出是 2 个浮点数的列表。
我在Keras 文档中找到了一个类似的示例
>>> y_true = [[0., 1.], [0., 0.]]
>>> y_pred = [[1., 1.], [1., 0.]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> mse = tf.keras.losses.MeanSquaredError()
>>> mse(y_true, y_pred).numpy()
0.5
根据示例的输出,我认为它会像这样计算 MSE
first_MSE = mse(y_true[0], y_pred[0])
second_MSE = mse(y_true[1], y_pred[1])
mse = (first_MSE + second_MSE) / 2
执行上述操作,我得到 0.5,如示例所示。这真的是引擎盖下发生的事情吗?
解决方案
是的,MeanSquaredError 首先计算为最后一个轴的平均值,然后是批次的平均值。
在此处计算平方差的最后一个轴的平均值: https ://github.com/tensorflow/tensorflow/blob/7ec285825c713af9bc741b8b65d09dd160ec8806/tensorflow/python/keras/losses.py#L1213
该类MeanSquaredError
使用默认减少,这意味着批量损失,这里:
https ://github.com/tensorflow/tensorflow/blob/7ec285825c713af9bc741b8b65d09dd160ec8806/tensorflow/python/keras/utils/losses_utils.py#L264
推荐阅读
- java - 如何从该 HashMap 中的对象访问对象的 HashMap。(爪哇)
- scala - 如何在Scala中将rdd对象转换为数据框
- php - 在PHP中将单个数组转换为多维数组?
- sql - Oracle SQL 按百分比范围查找人口
- python - 在 GroupBy 之后加入 PySpark
- office365 - Microsoft Graph API 返回 503/504 错误
- sql-server - 在 SQL Server 中不使用双重连接
- python - 如何检查元组列表是否包含所需的元素
- python - 谁能解释这 2 行 python 代码
- python - 无法将输入作为分数字符串转换为浮点数