python - 带有轴参数的Tensorflow tf.gather
问题描述
我正在使用 tensorflowtf.gather
从多维数组中获取元素,如下所示:
import tensorflow as tf
indices = tf.constant([0, 1, 1])
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
result = tf.gather(x, indices, axis=1)
with tf.Session() as sess:
selection = sess.run(result)
print(selection)
这导致:
[[1 2 2]
[4 5 5]
[7 8 8]]
我想要的是:
[1
5
8]
如何使用tf.gather
在指定轴上应用单个索引?(与此答案中指定的解决方法相同的结果:https ://stackoverflow.com/a/41845855/9763766 )
解决方案
您需要将 转换indices
为full indices
,并使用gather_nd
。可以通过这样做来实现:
result = tf.squeeze(tf.gather_nd(x,tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices[...,tf.newaxis]], axis=2)))
推荐阅读
- string - 如何从嵌套列表中删除一定长度的字符串?
- reactjs - 将我的 ReactJS Web 客户端直接连接到 DialogFlow
- django - Django Angular 403 Django 不接受 CSRF-cookie:“CSRF 令牌丢失或不正确。”
- python - django 在模型字段访问之前运行代码?(代理模型字段)
- javascript - 你如何更改谷歌 cse 中的默认细化标签?
- sql-server - T-SQL 为 MIN 和 MAX 返回不同的行值
- python - 不理解这种递归斐波那契实现
- mysql - Laravel MySQL:为电子商务设计多(三)级类别结构
- javascript - 令牌过期 1 小时后,acquireToken 不起作用
- php - 关闭对非链接在 htaccess 中重写的页面的访问