keras - 相当于 PyTorch 中 Keras 的 binary_crossentropy?
问题描述
我想将一些代码从 keras 移植到 pytorch,但我在 PyTorch 中找不到与 Keras 的 binary_crossentropy 等效的东西。PyTorch 的 binary_cross_entropy 与 keras 的行为不同。
import torch
import torch.nn.functional as F
input = torch.tensor([[ 0.6845, 0.2454],
[ 0.7186, 0.3710],
[ 0.3480, 0.3374]])
target = torch.tensor([[ 0., 1.],
[ 1., 1.],
[ 1., 1.]])
F.binary_cross_entropy(input, target, reduce=False)
#tensor([[ 1.1536, 1.4049],
# [ 0.3305, 0.9916],
# [ 1.0556, 1.0865]])
import keras.backend as K
K.eval(K.binary_crossentropy(K.variable(input.detach().numpy()), K.variable(target.detach().numpy())))
#[[11.032836 12.030124]
#[ 4.486187 10.02776 ]
#[10.394435 10.563424]]
有谁知道为什么这两个结果不同?谢谢!
解决方案
Keras 二元交叉熵采用y_true, y_pred
,而 Pytorch 采用相反的顺序,因此您需要将 Keras 行更改为
K.eval(K.binary_crossentropy(K.variable(target.detach().numpy()), K.variable(input.detach().numpy())))
通过这种方式,您可以获得正确的输出:
array([[ 1.15359652, 1.40486574],
[ 0.33045045, 0.99155325],
[ 1.05555284, 1.0864861 ]], dtype=float32)
推荐阅读
- pandas - Pandas - 输出中不存在 KeyError 作为列
- dictionary - 将 Ansible 字典列表转换为单个列表未按预期显示
- php - 如何在 shopify 中获取“在线商店”销售渠道产品?
- android - 是否可以强制卸载我的应用程序?
- puppeteer - 无法在无头浏览器中加入由 lib-jitsi-meet 创建的会议
- gnuplot - 你如何使用 GNUplot 画一个圆?
- javascript - input type="date" 将默认值设置为今天的日期
- cefsharp - cefsharp (chromiumwebbrowser) 的默认缓存路径在哪里?
- java - Mockito 模拟 retrytemplate.execute 并返回模拟响应
- flutter - 颤振错误:TabBarView 没有 TabController