pytorch - 用于使嵌入相似的 Pytorch 损失函数
问题描述
我正在研究一个嵌入模型,其中有一个 BERT 模型,它接受文本输入并输出一个多维向量。该模型的目标是为相似的文本找到相似的嵌入(高余弦相似度),为不同的文本找到不同的嵌入(低余弦相似度)。
在 mini-batch 模式下训练时,BERT 模型给出一个N*D
维度输出,其中N
是批大小,D
是 BERT 模型的输出维度。
另外,我有一个维度的目标矩阵,如果和在意义上相似,则N*N
它包含1
在第[i, j]
th 位置。sentence[i]
sentence[j]
-1
我想要做的是通过找到 BERT 输出中所有嵌入的余弦相似度并将其与目标矩阵进行比较来找到整个批次的损失/错误。
我所做的只是将张量与其转置相乘,然后取元素 Sigmoid。
scores = torch.matmul(document_embedding, torch.transpose(document_embedding, 0, 1))
scores = torch.sigmoid(scores)
loss = self.bceloss(scores, targets)
但这似乎不起作用。
有没有其他方法可以做到这一点?
PS 我想做的和本文描述的方法类似。
解决方案
要计算两个向量之间的余弦相似度,您将使用nn.CosineSimilarity
. 但是,我认为这不允许您从一组n
向量中获得配对相似性。幸运的是,您可以通过一些张量操作自己实现它。
让我们将您x
的 document_embedding称为嵌入大小。我们将采用和。所以是由 组成的。(n, d)
d
n=3
d=5
x
[x1, x2, x3].T
>>> x = torch.rand(n, d)
tensor([[0.8620, 0.9322, 0.4220, 0.0280, 0.3789],
[0.2747, 0.4047, 0.6418, 0.7147, 0.3409],
[0.6573, 0.3432, 0.5663, 0.2512, 0.0582]])
余弦相似度是归一化的点积。x@x.T
矩阵乘法将为您提供成对点积:其中包含:||x1||²
, <x1/x2>
, <x1/x3>
, <x2/x1>
, ||x2||²
, 等...
>>> sim = x@x.T
tensor([[1.9343, 1.0340, 1.1545],
[1.0340, 1.2782, 0.8822],
[1.1545, 0.8822, 0.9370]])
标准化采用所有范数的向量:||x1||
、||x2||
和||x3||
:
>>> norm = x.norm(dim=1)
tensor([1.3908, 1.1306, 0.9680])
构造包含归一化因子的矩阵:||x1||²
, ||x1||.||x2||
, ||x1||.||x3||
, ||x2||.||x1||
, ||x2||²
, 等...
>>> factor = norm*norm.unsqueeze(1)
tensor([[1.9343, 1.5724, 1.3462],
[1.5724, 1.2782, 1.0944],
[1.3462, 1.0944, 0.9370]])
然后规范化:
>>> sim /= factor
tensor([[1.0000, 0.6576, 0.8576],
[0.6576, 1.0000, 0.8062],
[0.8576, 0.8062, 1.0000]])
或者,避免必须创建范数矩阵的更快方法是在乘法之前进行归一化:
>>> x /= x.norm(dim=1, keepdim=True)
>>> sim = x@x.T
tensor([[1.0000, 0.6576, 0.8576],
[0.6576, 1.0000, 0.8062],
[0.8576, 0.8062, 1.0000]])
对于损失函数,我将nn.CrossEntropyLoss
直接在预测的相似度矩阵和目标矩阵之间应用,而不是应用 sigmoid + BCE。注:nn.CrossEntropyLoss
包括nn.LogSoftmax
.
推荐阅读
- c# - Xamarin 表单切换如何记住和更新用户设置(推送通知)
- php - PrestaShop 1.5 添加 PHP (reCaptcha)
- javascript - 我应该避免在减少中使用对象传播吗?
- javascript - 如何同时使用 module.exports 和 exports.default
- xml - 通过 XSLT 1.0 更改(反转)元素属性值中的日期顺序
- laravel - 将 DB 数组输出转换为 Eloquent 模型以使用 API 资源
- c++ - 为什么带有 and 条件的 if 语句没有按预期工作?
- javascript - Firefox 不加载脚本
- android - 以编程方式生成的 GridLayout 将列推到右侧
- ios - Firebase Crashlytics 迁移和查看仪表板