首页 > 解决方案 > 如何实现 K.sum() 以便聚合所有样本的所有元素?

问题描述

我正在为输入和输出灰度图像的 Keras 网络创建自定义损失。这种损失在执行期间应该执行的一件事是计算整个数据集中所有像素值的总和。

例如,如果第一张图像有 1 个白色像素和 3 个黑色,第二张图像有 2 个白色像素和 2 个黑色,那么我想返回 255 * 3 = 765。

但是,由于某种原因,简单的解决方案(下面的示例)似乎计算了图像之间的 batch_size * mean sum_elements。你能帮我解决这个问题吗?

def my_loss(y_true, y_pred):
    sum_elements = K.sum(y_true)
    return sum_elements

标签: tensorflowmachine-learningkerasneural-networkloss-function

解决方案


推荐阅读