python - 相当于torch.gather的张量流
问题描述
我有一个张量的形状(16, 4096, 3)
。我有另一个形状指数张量(16, 32768, 3)
。我正在尝试收集这些值dim=1
。这最初是在 pytorch 中使用收集功能完成的,如下所示-
# a.shape (16L, 4096L, 3L)
# idx.shape (16L, 32768L, 3L)
b = a.gather(1, idx)
# b.shape (16L, 32768L, 3L)
请注意,输出的大小b
与idx
. 但是,当我应用gather
tensorflow 的功能时,我得到了完全不同的输出。发现输出维度不匹配,如下所示 -
b = tf.gather(a, idx, axis=1)
# b.shape (16, 16, 32768, 3, 3)
我也尝试使用tf.gather_nd
但徒劳无功。见下文-
b = tf.gather_nd(a, idx)
# b.shape (16, 32768)
为什么我会得到不同形状的张量?我想得到与 pytorch 计算的相同形状的张量。
换句话说,我想知道torch.gather的tensorflow等价物。
解决方案
对于 2D 情况,有一种方法可以做到:
# a.shape (16L, 10L)
# idx.shape (16L,1)
idx = tf.stack([tf.range(tf.shape(idx)[0]),idx[:,0]],axis=-1)
b = tf.gather_nd(a,idx)
但是,对于 ND 情况,这种方法可能非常复杂
推荐阅读
- java - 休眠条件:关系“my_table”不存在
- ios - 是否可以在 ios 上读取 NFC 芯片的 Uid?
- arrays - mongodb 在数组中发送用户 ID 作为查询并返回集合中存在的用户的 json 对象
- vue.js - 在 vue 我对数据库和生命周期有疑问
- android - setTextColor 可编程使用颜色值
- plugins - Flutter 包仅适用于条带插件
- regex - 将索引增加 1 - 将 C 中的代码转换为 MATLAB
- github - GitHub 铅笔(编辑)按钮消失
- javascript - 如何读取 API 文档中的函数参数?
- ios - 无法访问 gcm 数据响应