首页 > 解决方案 > tf.gather_nd 的用法

问题描述

假设你有一个 3-tensor

data = np.reshape(np.arange(12), [2, 2, 3])
x = tf.constant(data)

将此视为由最后一个索引索引的 2x2 矩阵,我想从第一个矩阵中获取第一列,从第二个矩阵中获取第二列,从第三个矩阵中获取第二列。

我如何使用 tf.gather_nd 来做到这一点?

标签: tensorflow

解决方案


我在网上找到了以下教程,解释了如何处理这类问题:https ://geekyisawesome.blogspot.com/2018/05/fancy-indexing-in-tensorflow-getting.html

假设我们有一个 4x3 矩阵

M = tf.constant(np.arange(12).reshape(4,3))

现在假设您想要第一行的第三个元素、第二行的第二个元素、第三行的第一个元素和第四行的第二个元素。如教程中所述,这可以通过以下方式完成:

idx = tf.constant([2,1,0,1], tf.int32)
x = tf.gather_nd(M, tf.stack([tf.range(M.shape[0]), idx], axis=1))

但是如果 M 的行数未知怎么办?(和 idx 作为适当大小的整数张量)然后 tf.range(M.shape[0]) 将引发错误。我怎么能绕过呢?


推荐阅读