python - 自定义损失函数 keras TypeError
问题描述
我想在 keras 中编写一个自定义损失函数:
def cus_loss_fn(y_true, y_pred):
a = tf.convert_to_tensor(1, dtype=tf.float32)
lSRCC = tf.math.subtract(
a, tf.py_function(
spearmanr,
[tf.cast(y_pred, tf.float32), tf.cast(y_true, tf.float32)],
Tout=tf.float32))
lPLCC = tf.math.subtract(
a, tf.py_function(
pearsonr,
[tf.cast(y_pred, tf.float32), tf.cast(y_true, tf.float32)],
Tout=tf.float32))
alpha = tf.convert_to_tensor(0.5, dtype=tf.float32)
res = tf.math.add(lPLCC , tf.math.multiply(alpha,lSRCC ))
return (tf.convert_to_tensor(res,dtype=tf.float32))
但是当我尝试将数据拟合到我的模型中时出现以下错误。
TypeError: No loop matching the specified signature and casting was found for ufunc add
[[node cus_loss_fn/EagerPyFunc_1 (defined at <ipython-input-16-f7ce00c481ca>:6) ]] [Op:__inference_train_function_4985]
Errors may have originated from an input operation.
Input Source operations connected to node cus_loss_fn/EagerPyFunc_1:
sequential/dense_1/Softmax (defined at <ipython-input-17-a300d9883dd2>:4)
cus_loss_fn/EagerPyFunc (defined at <ipython-input-16-f7ce00c481ca>:4)
Function call stack:
train_function
我应该如何在 keras 中做到这一点?
解决方案
推荐阅读
- flutter - flutter clean 不清理任何东西/也不显示错误
- python - 用于灵活创建日期时间范围的 Python 函数
- html - 在特定位置绘制/放置元素到 SVG 图像上
- nginx - VPS 上的 Laradock 尝试访问域时出现 HTTP ERROR 500 错误
- flutter - Flutter Provider 和 sqflite 集成
- java - 如何使用事件来计算重复事件的结束日期
- java - 如何从具有条件的多对多关系中选择实体
- javascript - 如何等待 Subject.next
- java - Tomcat Servlet Logger 不打印所有内容
- javascript - “此交互失败”除了此错误,我根本找不到问题