keras - 将参数传递给 keras 损失函数的语法是什么?
问题描述
我试图用 from_logits=True 给我的 keras 神经网络分类交叉熵损失。但是,我不确定如何将其传递到代码中,因为它要求我指定目标和输出。
通常我可以使用:
network.compile(sgd, loss='categorical_crossentropy'),
但现在我不得不试试这个:
network.compile(sgd, loss=categorical_crossentropy(from_logits=True))
这给了我一个错误:
TypeError: categorical_crossentropy() missing 2 required positional arguments: 'target' and 'output'
我能想到的最好的是:
network.compile(sgd, loss=categorical_crossentropy(y_true, network.output, from_logits=True))
我不知道要为 y_true 放什么,因为这不是网络的一部分。我在网上浏览了一下,但没有遇到任何指定如何执行此操作的内容,包括奇怪的 keras 文档。
解决方案
Keras 损失严格地需要两个参数:(y_true
地面实况数据)和y_pred
(模型的输出)。
如果要使用具有不同签名的函数,则必须将其包装以遵循正确的签名。
import keras.backend as K
def cc_from_logits(y_true, y_pred):
return K.categorical_crossentropy(y_true, y_pred, from_logits=True, axis=-1)
model.compile(loss=cc_from_logits)
我非常相信这cc_with_logits
会带来与 softmax + 'categorical_crossentropy'
.
推荐阅读
- javascript - 样式...的组件是动态创建的。您可能会看到此警告,因为您在另一个组件中调用了 styled
- powershell - 文件大小加倍,同时替换变量中的值并使用不同的名称输出
- html - CSS Gradient - 太多的颜色中断
- javascript - React - 如何在不在子组件中添加额外代码的情况下从一个组件调用方法到另一个组件
- oracle-apex - 交互式网格 - 交互式网格的数据库约束错误处理
- map-matching - 使用不同的方式操作地图?
- r - Rstudio:为什么适合统计的 M2 适用于四个分级响应模型中的单个案例?
- android - 应用程序关闭时的定期任务和通知
- iot - 累积数据类型处理
- javascript - 在 JavaScript 画布 API 中设置允许的绘图区域