python - Output of custom loss in Keras
问题描述
I know there are many questions treating custom loss functions in Keras but I've been unable to answer this even after 3 hours of googling.
Here is a very simplified example of my problem. I realize this example is pointless but I provide it for simplicity, I obviously need to implement something more complicated.
from keras.backend import binary_crossentropy
from keras.backend import mean
def custom_loss(y_true, y_pred):
zeros = tf.zeros_like(y_true)
index_of_zeros = tf.where(tf.equal(zeros, y_true))
ones = tf.ones_like(y_true)
index_of_ones = tf.where(tf.equal(ones, y_true))
zero = tf.gather(y_pred, index_of_zeros)
one = tf.gather(y_pred, index_of_ones)
loss_0 = binary_crossentropy(tf.zeros_like(zero), zero)
loss_1 = binary_crossentropy(tf.ones_like(one), one)
return mean(tf.concat([loss_0, loss_1], axis=0))
I do not understand why training the network with the above loss function on a two class dataset does not yield the same result as training with the built in binary-crossentropy
loss function.
Thank you!
EDIT: I edited the code snippet to include the mean as per comments below. I still get the same behavior however.
解决方案
我终于弄明白了。tf.where
当形状为“未知”时,该函数的行为非常不同。要修复上面的代码片段,只需在声明函数后插入以下行:
y_pred = tf.reshape(y_pred, [-1])
y_true = tf.reshape(y_true, [-1])
推荐阅读
- asynchronous - Kotlin 中的异步计划作业
- python - 使用 Keras 和 seq2seq 模型无法调用“张量”对象
- python - 如果我有 python 3.7,如何使用 python 3.6?
- php - 如何在 PHP Curl 或 PHP Soapclient 中进行 XML 请求?
- tensorflow - 在 keras 中获取中间层的输出会引发名称错误
- php - 从 MySQL 插入和显示图像
- android - 无法为 org.gradle.api.Project 类型的项目“:app”获取未知属性“src”。在 Ubuntu 上
- performance - 访问全局数据是否比访问本地数据更快?
- javascript - 使用 PhantomJS 在同一行打印数据
- react-native - React Native fb-sdk 找不到符号 CallbackManager