python - 如何在 keras 损失函数中正确使用 from_logits 进行二进制分类?
问题描述
这是一个使用 tensorflow 的示例神经网络,
x = tf.keras.layers.Input((None,))
x = tf.keras.layers.Dense(100)(x)
x1 = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=x, outputs=x1)
model.compile("adam", loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
model.fit(x, y)
我在二进制损失中正确使用了 from_logits 吗?或者我应该将我的输出层更改为,
x1 = tf.keras.layers.Dense(1, activation="sigmoid")
任何人都可以消除这种困惑吗?
我在没有from_logits
损失函数的情况下尝试了以下模型,并且得到了很好的结果,但是如果我使用 from_logits 则不会得到很好的结果。
x = tf.keras.layers.Input((None,))
x = tf.keras.layers.Dense(100)(x)
x1 = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=x, outputs=x1)
model.compile("adam", loss=tf.keras.losses.BinaryCrossentropy())
model.fit(x, y)
解决方案
我一直在做一些研究,因为我在尝试使用 Keras 在 Tensorflow 2.0 中复制 Tensorflow 和 Pytorch 架构时遇到了一些类似的问题。
在这篇文章中,您可以简要了解该from_logits
参数存在的原因。简而言之,训练二元分类模型的原始 Tensorflow 方法似乎是从无界层(例如使用线性激活函数)获取输出并使用“logits”计算二元交叉熵。在神经网络上下文中,“logits”是最后一个无界层的输出。但是,这不是“logits”的正确数学定义。
我看过一些帖子说 Keras 中用于训练的 Sigmoid 激活函数 + Binary Crossentropy 是不稳定的。尽管如此,还有其他一些帖子试图表明情况并非如此。例如,在这篇文章中,似乎有几个很好的理由说明它不稳定。
在我看来,Keras 或 Tensorflow 2.0 中最好的方法是在最后一层使用不带 logits 的 BinaryCrossentropy 和 sigmoid 激活函数。这更简单,它应该可以正常工作。
推荐阅读
- ansible - 是否可以使用 Ansible lineinfile 模块添加同一行?
- python - 无需在 Python 中进行身份验证即可将大文件上传到 S3
- git - 在编写日志条目时防止 GNU Emacs 退出
- asp.net - 尝试在 Microsoft Azure 上部署 ASP.NET 动态数据实体网站
- android - 旋转到横向时的Exoplayer方向问题
- python - Evernote - OAuth 访问令牌 - Python
- unity3d - 这个 Unity 脚本没有理由让精灵移动,但它会移动
- php - 使用 pdo 时数据未插入数据库
- python - Python MySQL连接器未连接
- spring-boot - 将 JSF 依赖添加到 Spring Boot 会阻止从文件系统中读取静态资源