python - 实现余弦相似度损失给出了与 Tensorflow 不同的答案
问题描述
我正在用我的自定义 python 脚本实现余弦相似度损失,但它给了我一个与 TensorFlow 截然不同的答案。先看TensorFlow's
答案:-
y_true = [[0., 1.], [1., 1.]]
y_pred = [[0., 1.], [0., 1.]]
loss = tf.keras.losses.CosineSimilarity()
print(loss(y_true, y_pred).numpy())
输出:
>>> -0.8535534
根据 TensorFlow 文档,计算损失的公式是:-
我用普通的python实现了同样的功能:-
def cosine_similarity(y_true, y_pred):
loss = -np.sum(np.linalg.norm(y_true) * np.linalg.norm(y_pred))
return loss
print(cosine_similarity(y_true, y_pred))
输出:
>>> -2.4494897427831783
我不知道为什么我得到-2.45
并且TensorFlow
正在输出-0.85
. 有什么解决方案可以让我的答案与 TensorFlow 相匹配吗?
解决方案
在浏览了一些文档后,
结果tf.keras.losses.CosineSimilarity()
和您的功能不同有两个原因:
- 如这里的示例所示,在
CosineSimiliraty()
函数中,L2_normalisation 沿轴 = 1完成
np.linalg.norm()
因为没有给出轴,所以在整个数组上执行。此外,将结果相加。
y_true = [[0., 1.], [1., 1.]]
y_pred = [[0., 1.], [0., 1.]]
print(tf.math.l2_normalize(y_true,axis=1))
print(np.linalg.norm(y_true))
Outputs
#[[0. 1. ]
# [0.70710677 0.70710677]]
# 1.7320508075688772
# Result from np.linalg.norm() is obtained by summing :
#[[0. 0.57735026]
#[0.57735026 0.57735026]]
- 其次,我不知道为什么,但考虑到上面链接中给出的示例,在对值求和之前,
np.mean
沿同一轴应用。他们可能会忘记在您使用的公式中对其进行精确化。
a=tf.math.l2_normalize(y_true,axis=1)
b=tf.math.l2_normalize(y_pred,axis=1)
print(a)
print(b)
print(np.mean(a*b,axis=1)
print(-np.sum(np.mean(a*b,axis=1)))
#[[0. 1.][0.70710677 0.70710677]]
#[[0. 1.][0. 1.]]
#[0.5 0.35355338]
# -0.8535534
#
- 我不确定,但
np.linalg.norm()
似乎给出了向量/矩阵的范数,其中tensorflow为您提供了相同的归一化矩阵(沿您选择的轴)
所以不要使用np.linalg.norm()
使用Tensorflow函数tf.math.l2_normalized(myarray,axis=1)
def cosine_sim(y_true,y_pred):
norm_true=tf.math.l2_normalize(y_true,axis=1)
norm_pred=tf.math.l2_normalize(y_pred,axis=1)
loss =-np.sum(np.mean(norm_true*norm_pred,axis=1))
return loss
推荐阅读
- macos - 如何在 Mac OS 中分析 dmp 文件?
- jquery - 使用复选框更改带有过滤器的 url 并选择
- python - 如何计算某些连续日期范围的汇总统计数据
- python - 舍入双精度值并转换为整数
- angular - 如何在Angular 4中绑定多个选项/选择
- php - 如果xml节点没有子节点,如何删除它
- python - 使用数组和堆栈递归地评估表达式
- swift - 将数据附加到 Metal 中 MTLBuffer 的现有内容
- android - 如何使用 Cordova 使我的应用程序与旧的 Android 版本兼容?
- powershell - 停止PowerShell的子进程