首页 > 解决方案 > 张量流中张量对象的非连续索引切片(高级索引,如numpy)

问题描述

我研究了 tensorflow 中不同的切片方式,即tf.gathertf.gather_nd. 在 tf.gather 中,它只对一个维度进行切片,并且在tf.gather_nd其中只接受一个indices应用于输入张量的切片。

我需要的是不同的,我想使用两个不同的张量对输入张量进行切片;一个切片在行上,第二个切片在列上,它们的形状不一定相同。

例如:

假设这是我想要提取其中一部分的输入张量。

input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])

第二个是:

 rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

第三张量:

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

现在,我想input_tf使用rows_tfand进行切片columns_tf[1 2 5]行和[1]中的索引columns_tf。同样,[1 2 5]带有[2]in 的行columns_tf

或者,[1 4 6][2].

总体而言,中的每个索引rows_tf,与中相同的索引columns_tf都会提取部分input_tf

因此,预期的输出将是:

[[8.3356,    0.,        8.457685 ],
 [0.,        6.103182,  8.602337 ],
 [8.8974,    7.330564,  0.       ],
 [0.,        3.8914037, 5.826657 ],
 [8.8974,    0.,        8.283971 ],
 [6.103182,  3.0614321, 5.826657 ],
 [7.330564,  0.,        8.283971 ],
 [6.103182,  3.8914037, 0.       ]]

例如,这里的第一行[8.3356, 0., 8.457685 ]是使用

rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)

有几个关于 tensorflow 切片的问题,尽管他们使用了tf.gatherortf.gather_nd并且tf.stack没有给出我想要的输出。

无需提及,numpy我们可以通过调用:input_tf[rows_tf, columns_tf].

我还查看了这个高级索引,它试图模拟 numpy 中可用的高级索引,但它仍然不像 numpy 灵活https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb

这是我尝试过的不正确的方法:

tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)

该代码的维度输出(8, 1, 3, 8)是完全不正确的。

提前致谢!

标签: pythontensorflowkerasslice

解决方案


这个想法是首先将稀疏索引(通过连接行索引和列索引)作为一个列表。然后,您可以使用gather_nd来检索值。


tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594,  8.655439,  0.,        0.       ],
                         [0.,        8.3356,    0.,        8.8974   ],
                         [0.,        0.,        6.103182,  7.330564 ],
                         [6.609862,  0.,        3.0614321, 0.       ],
                         [9.497023,  0.,        3.8914037, 0.       ],
                         [0.,        8.457685,  8.602337,  0.       ],
                         [0.,        0.,        5.826657,  8.283971 ],
                         [0.,        0.,        0.,        0.       ]])
rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])
columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
    tf.tile(columns_tf, multiples=[1, 3]), 
    shape=[-1, 1])
sparse_indices = tf.reshape(
    tf.concat([rows_tf, columns_tf], axis=-1), 
    shape=[-1, 2])

v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  #print 'rows\n', sess.run(rows_tf)
  #print 'columns\n', sess.run(columns_tf)
  print sess.run(v)

结果将是:

[[ 8.3355999   0.          8.45768547]
 [ 0.          6.10318184  8.60233688]
 [ 8.8973999   7.33056402  0.        ]
 [ 0.          3.89140368  5.82665682]
 [ 8.8973999   0.          8.28397083]
 [ 6.10318184  3.06143212  5.82665682]
 [ 7.33056402  0.          8.28397083]
 [ 6.10318184  3.89140368  0.        ]]

推荐阅读