首页 > 解决方案 > 在pytorch中广播2D索引选择?

问题描述

我有一个 shape 的张量P.shape=[N,k]和一个 shape 的索引张量ind.shape=[L,N],其中( always)中ind[i,j] 有一个列。我希望创建一个新的 dims 张量,其中的功能可以通过以下方式使用 for 循环生成:P[j]ind[i,j] < k[L,n]

    new= []
    num_points = P.shape[-1]
    for experiment in range(ind.shape[0]):
        new.append(P[torch.arange(num_points),ind[exp]])
    new= torch.stack(new)

但是L真的很大,代码非常慢。
使用repeat我设法复制了功能

new = P.unsqueeze(1).repeat(1,L,1,1).reshape(-1,*P.shape[1:])
new = new.gather(2,ind.unsqueeze(2)).squeeze(2)

但是L真的很大,我有一个OOM例外.repeat(1,L,1,1)。我想知道我是否可以使用广播完成类似的事情?

标签: pythonpytorch

解决方案


供将来参考
虽然repeat成本很高,expand但正是我所需要的

new = P.unsqueeze(1).expand(-1,L,-1,-1).reshape(-1,*P.shape[1:])
new = new.gather(2,ind.unsqueeze(2)).squeeze(2)

推荐阅读