tensorflow - tf.gather_nd 的用法
问题描述
假设你有一个 3-tensor
data = np.reshape(np.arange(12), [2, 2, 3])
x = tf.constant(data)
将此视为由最后一个索引索引的 2x2 矩阵,我想从第一个矩阵中获取第一列,从第二个矩阵中获取第二列,从第三个矩阵中获取第二列。
我如何使用 tf.gather_nd 来做到这一点?
解决方案
我在网上找到了以下教程,解释了如何处理这类问题:https ://geekyisawesome.blogspot.com/2018/05/fancy-indexing-in-tensorflow-getting.html
假设我们有一个 4x3 矩阵
M = tf.constant(np.arange(12).reshape(4,3))
现在假设您想要第一行的第三个元素、第二行的第二个元素、第三行的第一个元素和第四行的第二个元素。如教程中所述,这可以通过以下方式完成:
idx = tf.constant([2,1,0,1], tf.int32)
x = tf.gather_nd(M, tf.stack([tf.range(M.shape[0]), idx], axis=1))
但是如果 M 的行数未知怎么办?(和 idx 作为适当大小的整数张量)然后 tf.range(M.shape[0]) 将引发错误。我怎么能绕过呢?
推荐阅读
- nginx - 为什么 nginx 将资产重定向到主页?
- java - 无法从谷歌构建“App Engine 标准环境中 Java 的 Bookshelf 应用”示例,pom.xml 有问题吗?
- multithreading - 为什么使用 Iterator::map 生成线程不能并行运行线程?
- scala - Spark:按 ID 创建 JSON 组
- unicode - 在 ODBC 应用程序中插入 Unicode 数据时,如何确定应该使用的编码
- css - 溢出设置为可见最近开始调整父 div 的高度以适合子级
- ads - MoPub 返回“广告服务器未返回广告资源”
- c# - 如何在 URL 中打开我的 GET RestRequest 以登录
- python-3.x - 不能在 Python 中使用 OpenCV GeneralizedHoughTransform 类
- python - Python - 句子结尾和其他句号之间的区别