首页 > 解决方案 > 如何在 Keras 中查找数组中非零的数量?

问题描述

我试图在 Keras 的自定义损失函数中找到零的数量。

def root_mean_squared_error(y_true, y_pred):

这里输入这个损失函数的地方:

model.compile(optimizer=sgd, loss=root_mean_squared_error,
              metrics=[metrics.mse, root_mean_squared_error])

我正在尝试查找数组中非零值的y_true数量并将我的数字除以该值。

如何在 中找到非零元素的数量y_true

标签: pythontensorflowkeras

解决方案


您可以tf.count_nonzero通过 Keras 后端使用 API。

from keras import backend as K
import numpy as np

def custom_loss(y_true, y_pred):
    return y_pred / K.cast(K.tf.count_nonzero(y_true), K.tf.float32)

y_t = K.placeholder((1,2))
y_p = K.placeholder((1,2))

loss = custom_loss(y_t, y_p)

print(K.get_session().run(loss, {y_t: np.array([[1,1]]), y_p: np.array([[2,4]])}))

结果是

[[1. 2.]]

推荐阅读