首页 > 解决方案 > 在 PyTorch 中计算批量成对 Sinkhorn 距离

问题描述

我有两个张量,它们的形状相同。我想使用GeomLoss.

我试过的:

import torch
import geomloss  # pip install git+https://github.com/jeanfeydy/geomloss

a = torch.rand((8,4))
b = torch.rand((8,4))

geomloss.SamplesLoss('sinkhorn')(a,b)
# ^ input shape [batch, feature_dim]
# will return a scalar value

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(1),b.unsqueeze(1))  
# ^ input shape [batch, n_points, feature_dim]
# will return a tensor of size [batch] of distances between a[i] and b[i] for each i

但是,我想计算结果张量应该为 size 的成对距离[batch, batch]。为了实现这一点,我尝试了以下使用广播:

geomloss.SamplesLoss('sinkhorn')(a.unsqueeze(0), b.unsqueeze(1))

但我收到了这个错误信息:

ValueError:样本应该具有相同的批量大小xy

标签: pytorch

解决方案


由于文档没有提供有关如何使用距离的前向功能的示例。这是一种方法,它需要您调用距离函数batch时间。

我们将逐行构造距离矩阵。线i对应于距离a[i]<->b[0], a[i]<->b[1], 到a[i]<->b[batch]。为此,我们需要为每一行构造i一个(8x4)重复版本的张量a[i]

这将做:

a_i = torch.stack(8*[a[i]], dim=0)

a[i]然后我们计算每个批次之间的距离b

dist(a_i.unsqueeze(1), b.unsqueeze(1))

总共有几batch行,我们可以构造我们的最终张量stack


这是完整的代码:

batch = a.shape[0]
dist = geomloss.SamplesLoss('sinkhorn')
distances = [dist(torch.stack(batch*[a[i]]).unsqueeze(1), b.unsqueeze(1)) for i in range(batch)]
D = torch.stack(distances)

推荐阅读