python - 在自定义损失 mse 函数中添加 if 条件
问题描述
我正在尝试使用带有 keras 的自定义损失 mse 函数来更加强调 DNN 模型的小价值。
我确实尝试过这样的事情:
import keras.backend as K
def scaled_mse(y_true,y_pred):
loss = K.square(y_pred - y_true)
if y_true>0.1:
loss=loss/K.abs(y_true)
return loss
我的 ML 模型的输入是 41 值。输出只有 1 个值。
但它返回此错误: OperatorNotAllowedInGraphError: using a tf.Tensor
as a Python bool
is not allowed in Graph execution. 使用 Eager 执行或使用 @tf.function 修饰此函数。
谢谢你的帮助 !
解决方案
您的代码的问题是if
每个 keras 后端的操作员都不可区分。但是,如果您考虑此操作的差异化,则相当简单;仅当条件评估为 True 时,梯度才取决于一项,如果条件评估为 False,则同样取决于另一项。因此,修复应该很简单。Keras 后端提供了switch()
我认为本质上是条件语句的可微分形式的操作。尝试使用它。
switch()
接受三个参数:第一个是条件表达式,第二个是张量,如果条件计算为真,则从该张量中获取值,第三个是张量,如果条件计算为假,则从该张量中获取值。所以你的代码可能看起来像这样:
import keras.backend as K
def scaled_mse(y_true,y_pred):
# both loss and scale loss are size 41
loss = K.square(y_pred - y_true)
scaled_loss = loss/K.abs(y_true)
# first arg (bool expression) will evaluate to a bool tensor of size 41,
# which then indicates from which tensor to select the value for each
# element of the output tensor
scale_mse_loss = K.switch((y_true > 0.1),scaled_loss,loss)
return scale_mse_loss
推荐阅读
- sql-server - 如何在将数据添加到表中时动态插入日期?
- github - 什么是 GraphQL 查询以确定组织中是否存在 GitHub 包?
- servicestack - 是否可以在 Visual Studio 2017 或 2019 中使用 ServiceStack 模板?
- python - 改进 tensorflow 训练模型
- python - 从大流中选择一个随机元素有什么大不了的?
- javascript - 将 Json 数据映射到新的键值对 JavaScript 对象
- c++ - 如何使用 string::replace 方法写入文件?
- python - Match all identifiers in a string
- shopify - 推出新的 Shopify 主题和最佳方法
- r - 使用 R 闪亮的应用程序创建箱线图的问题