首页 > 解决方案 > 切片类似于 numpy np.ix_ 的 2D 张量

问题描述

我在这里学习了如何在一维上切割张量。

我已经学习了如何对 2D 张量进行切片,并在此处给出特定值的 1D 张量。

两者都使用tf.gather(),但我很确定我需要tf.gather_nd(),尽管我显然用错了。

在 numpy 中,我有一个 5x5 2D 数组,我可以通过使用np.ix_()行和列索引来切片一个 2x2 数组(我总是需要相同的行和列索引,从而产生一个方阵):

import numpy as np

a = np.array([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])

a
array([[ 1,  2,  3,  4,  5],
      [ 2,  1,  6,  7,  8],
      [ 3,  6,  1,  9, 10],
      [ 4,  7,  9,  1, 11],
      [ 5,  8, 10, 11,  1]])
a[np.ix_([1,3], [1,3])]
array([[1, 7],
      [7, 1]])

阅读tf.gather_nd()文档我认为这是在 TF 中执行此操作的方法,但我使用错了:

import tensorflow as tf

a = tf.constant([[1,2,3,4,5],[2,1,6,7,8],[3,6,1,9,10],[4,7,9,1,11],[5,8,10,11,1]])

tf.gather_nd(a, [[1,3], [1,3]])
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([7, 7])>

我将不得不做类似的事情:

tf.gather_nd(a, [[[1,1], [1,3]],[[3,1],[3,3]]])
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 7],
      [7, 1]])>

这使我陷入了另一个我不喜欢的兔子洞。当然,我的索引向量要长得多。

顺便说一句,我的索引本身就是一维整数张量。所以底线我想a用与我相同的行和列索引来切片np._ix(),我的索引类似于:

idx = tf.constant([1, 3])

# tf.gather_nd(a, indices = "something with idx")

标签: pythontensorflow

解决方案


要使用长度为 d 的 1D 张量对 nxn 2D 数组进行切片,从而生成具有指定索引的 dxd 2D 数组,可以使用tf.repeat,tf.tile然后来完成tf.stack

n = 5
a = tf.constant(np.arange(n * n).reshape(n, n)) # 2D nxn array
idx = [1,2,4] # 1D tensor with length d
d = tf.shape(idx)[0]
ix_ = tf.reshape(tf.stack([tf.repeat(idx,d),tf.tile(idx,[d])],1),[d,d,2])
target = tf.gather_nd(a,ix_) # 2D dxd array
print(a)
print(target)

预期产出:

tf.Tensor(
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]], shape=(5, 5), dtype=int64)
tf.Tensor(
[[ 6  7  9]
 [11 12 14]
 [21 22 24]], shape=(3, 3), dtype=int64)

推荐阅读