首页 > 解决方案 > 用于使嵌入相似的 Pytorch 损失函数

问题描述

我正在研究一个嵌入模型,其中有一个 BERT 模型,它接受文本输入并输出一个多维向量。该模型的目标是为相似的文本找到相似的嵌入(高余弦相似度),为不同的文本找到不同的嵌入(低余弦相似度)。

在 mini-batch 模式下训练时,BERT 模型给出一个N*D维度输出,其中N是批大小,D是 BERT 模型的输出维度。

另外,我有一个维度的目标矩阵,如果和在意义上相似,则N*N它包含1在第[i, j]th 位置。sentence[i]sentence[j]-1

我想要做的是通过找到 BERT 输出中所有嵌入的余弦相似度并将其与目标矩阵进行比较来找到整个批次的损失/错误。

我所做的只是将张量与其转置相乘,然后取元素 Sigmoid。

scores = torch.matmul(document_embedding, torch.transpose(document_embedding, 0, 1))
scores = torch.sigmoid(scores)

loss = self.bceloss(scores, targets)

但这似乎不起作用。

有没有其他方法可以做到这一点?

PS 我想做的和本文描述的方法类似。

标签: pytorchtensorembeddingbert-language-modelloss

解决方案


要计算两个向量之间的余弦相似度,您将使用nn.CosineSimilarity. 但是,我认为这不允许您从一组n向量中获得配对相似性。幸运的是,您可以通过一些张量操作自己实现它。

让我们将您x的 document_embedding称为嵌入大小。我们将采用和。所以是由 组成的。(n, d)dn=3d=5x[x1, x2, x3].T

>>> x = torch.rand(n, d)
tensor([[0.8620, 0.9322, 0.4220, 0.0280, 0.3789],
        [0.2747, 0.4047, 0.6418, 0.7147, 0.3409],
        [0.6573, 0.3432, 0.5663, 0.2512, 0.0582]])

余弦相似度是归一化的点积。x@x.T 矩阵乘法将为您提供成对点积:其中包含:||x1||², <x1/x2>, <x1/x3>, <x2/x1>, ||x2||², 等...

>>> sim = x@x.T
tensor([[1.9343, 1.0340, 1.1545],
        [1.0340, 1.2782, 0.8822],
        [1.1545, 0.8822, 0.9370]])

标准化采用所有范数的向量:||x1||||x2||||x3||

>>> norm = x.norm(dim=1)
tensor([1.3908, 1.1306, 0.9680])

构造包含归一化因子的矩阵:||x1||², ||x1||.||x2||, ||x1||.||x3||, ||x2||.||x1||, ||x2||², 等...

>>> factor = norm*norm.unsqueeze(1)
tensor([[1.9343, 1.5724, 1.3462],
        [1.5724, 1.2782, 1.0944],
        [1.3462, 1.0944, 0.9370]])

然后规范化:

>>> sim /= factor
tensor([[1.0000, 0.6576, 0.8576],
        [0.6576, 1.0000, 0.8062],
        [0.8576, 0.8062, 1.0000]])

或者,避免必须创建范数矩阵的更快方法是在乘法之前进行归一化:

>>> x /= x.norm(dim=1, keepdim=True)
>>> sim = x@x.T
tensor([[1.0000, 0.6576, 0.8576],
        [0.6576, 1.0000, 0.8062],
        [0.8576, 0.8062, 1.0000]])

对于损失函数,我将nn.CrossEntropyLoss直接在预测的相似度矩阵和目标矩阵之间应用,而不是应用 sigmoid + BCE。注:nn.CrossEntropyLoss包括nn.LogSoftmax.


推荐阅读