python - 在pytorch中广播2D索引选择?
问题描述
我有一个 shape 的张量P.shape=[N,k]
和一个 shape 的索引张量ind.shape=[L,N]
,其中( always)中ind[i,j]
有一个列。我希望创建一个新的 dims 张量,其中的功能可以通过以下方式使用 for 循环生成:P[j]
ind[i,j] < k
[L,n]
new= []
num_points = P.shape[-1]
for experiment in range(ind.shape[0]):
new.append(P[torch.arange(num_points),ind[exp]])
new= torch.stack(new)
但是L
真的很大,代码非常慢。
使用repeat
我设法复制了功能
new = P.unsqueeze(1).repeat(1,L,1,1).reshape(-1,*P.shape[1:])
new = new.gather(2,ind.unsqueeze(2)).squeeze(2)
但是L
真的很大,我有一个OOM例外.repeat(1,L,1,1)
。我想知道我是否可以使用广播完成类似的事情?
解决方案
供将来参考
虽然repeat
成本很高,expand
但正是我所需要的
new = P.unsqueeze(1).expand(-1,L,-1,-1).reshape(-1,*P.shape[1:])
new = new.gather(2,ind.unsqueeze(2)).squeeze(2)
推荐阅读
- c# - 在 Sonarqube 中移除这个多余的跳转
- vba - 当我乘以 2 个单元格时出现值错误 - vba
- android - 无法在应用程序 gradle 文件中引用 BuildConfig - Android Studio
- perl - 附加文件间歇性失败
- c# - 如何从必须保留在主线程上的方法统一调用异步方法
- c++ - 是否允许编译器优化掉局部 volatile 变量?
- c# - 当焦点位于单选按钮或复选框上时按“c”时表单关闭
- node.js - 嵌套对象数组更新mongodb
- java - 反转 int 数组,但结果只完成了一半
- c# - 启用 TLS1.2 后 Asp.net 发送邮件失败