python-3.x - 如何对三元组损失函数的掩码计算进行矢量化
问题描述
假设我有一个lst
长度为 的数字列表N
,以及两个数字epsilon
和tau
。我想找到(N,N,N)
掩码矩阵mask
,使得mask[i][j][k]=1
当且仅当
abs(lst[i] - lst[j]) <= epsilon and abs(lst[i] - lst[k]) >= tau
这是我尝试过的:
d_mat = torch.cdist(lst.unsqueeze(0), lst.unsqueeze(0))
within_eps = torch.where(dmat <= eps, 1, 0)
over_tau = torch.where(dmat >= tau, 1, 0)
mask = torch.zeros((N,N,N))
for i in range(N):
for j in range(N):
for k in range(N):
if within_eps[i][j] == 1 and over_tau[i][k] == 1:
mask[i][j][k] = 1
else:
mask[i][j][k] = 0
所以基本上我天真地做了。你能用步骤告诉我你是如何为此提出矢量化的吗?
解决方案
您成功创建了 2ddmat
成对距离。现在您可以torch.logical_and
用于创建蒙版:
mask = torch.logical_and(dmat[..., None] <= eps, dmat[:, None, :] >= tau)
如果您想明确说明距离计算(并且效率较低),您可以:
mask = torch.logical_and(torch.abs(lst[:, None, None] - lst[None, :, None]) <= eps,
torch.abs(lst[:, None, None] - lst[None, None, :]) >= tau)