python - 张量流中张量对象的非连续索引切片(高级索引,如numpy)
问题描述
我研究了 tensorflow 中不同的切片方式,即tf.gather
和tf.gather_nd
. 在 tf.gather 中,它只对一个维度进行切片,并且在tf.gather_nd
其中只接受一个indices
应用于输入张量的切片。
我需要的是不同的,我想使用两个不同的张量对输入张量进行切片;一个切片在行上,第二个切片在列上,它们的形状不一定相同。
例如:
假设这是我想要提取其中一部分的输入张量。
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
第二个是:
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
第三张量:
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
现在,我想input_tf
使用rows_tf
and进行切片columns_tf
。[1 2 5]
行和[1]
中的索引columns_tf
。同样,[1 2 5]
带有[2]
in 的行columns_tf
。
或者,[1 4 6]
与[2]
.
总体而言,中的每个索引rows_tf
,与中相同的索引columns_tf
都会提取部分input_tf
。
因此,预期的输出将是:
[[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
例如,这里的第一行[8.3356, 0., 8.457685 ]
是使用
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)
有几个关于 tensorflow 切片的问题,尽管他们使用了tf.gather
ortf.gather_nd
并且tf.stack
没有给出我想要的输出。
无需提及,numpy
我们可以通过调用:input_tf[rows_tf, columns_tf]
.
我还查看了这个高级索引,它试图模拟 numpy 中可用的高级索引,但它仍然不像 numpy 灵活https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
这是我尝试过的不正确的方法:
tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)
该代码的维度输出(8, 1, 3, 8)
是完全不正确的。
提前致谢!
解决方案
这个想法是首先将稀疏索引(通过连接行索引和列索引)作为一个列表。然后,您可以使用gather_nd
来检索值。
tf.reset_default_graph()
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
rows_tf = tf.reshape(rows_tf, shape=[-1, 1])
columns_tf = tf.reshape(
tf.tile(columns_tf, multiples=[1, 3]),
shape=[-1, 1])
sparse_indices = tf.reshape(
tf.concat([rows_tf, columns_tf], axis=-1),
shape=[-1, 2])
v = tf.gather_nd(input_tf, sparse_indices)
v = tf.reshape(v, [-1, 3])
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
#print 'rows\n', sess.run(rows_tf)
#print 'columns\n', sess.run(columns_tf)
print sess.run(v)
结果将是:
[[ 8.3355999 0. 8.45768547]
[ 0. 6.10318184 8.60233688]
[ 8.8973999 7.33056402 0. ]
[ 0. 3.89140368 5.82665682]
[ 8.8973999 0. 8.28397083]
[ 6.10318184 3.06143212 5.82665682]
[ 7.33056402 0. 8.28397083]
[ 6.10318184 3.89140368 0. ]]
推荐阅读
- python - 如何将python的结果提取到xls文件中
- java - 如何更改其数据的数量和内容
- python - 叠加图像质量随着实时视频源中的每一帧而降低
- autodesk-forge - IOS浏览器不在forge查看器中渲染模型
- sql - 将多个值关联到另一个表的单个值的最佳方法是什么?
- pyspark - 根据其他 2 列中的值向数据框添加新列(需要 Pyspark)
- r - 如何更改年份格式
- swift - SwiftRealm 如何检查领域对象是否有主键?
- visual-studio-code - VS Code 无法启动终端进程:CreateProcess failed
- java - 与 Docker 容器的行为差异