首页 > 解决方案 > 如何在张量流中进行花式索引

问题描述

我有一个具有(假设)形状(5、3、5)的 Tensorflow 张量 A。我想得到一个形状为 (5, 3) 的张量 B 使得

# B = [A[0, :, 0], A[1, :, 1], A[2, :, 2], ...]

我想在不使用任何 for 循环的情况下实现此索引。使用 numpy one 可以:

import numpy as np
# A.shape = (5, 3, 5)
B = A[np.arange(A.shape[0]), :, np.arange(A.shape[2])]

任何建议如何使用 Tensorflow 做到这一点?

标签: pythonnumpytensorflow

解决方案


有两种方法可以实现您的目标。

import tensorflow as tf

a = tf.random_normal(shape=(5,3,5))

# method 1: take the diagonal after transpose
b_diag = tf.matrix_diag_part(tf.transpose(a,[1,0,2])) # shape = (3,5)
result1 = tf.transpose(b_diag,[1,0])

# method 2: take the value by indices
indices = tf.stack([tf.range(tf.shape(a)[0])]*2,axis=-1)
# [[0 0]
#  [1 1]
#  [2 2]
#  [3 3]
#  [4 4]]
result2 = tf.gather_nd(tf.transpose(a,[0,2,1]),indices)

with tf.Session() as sess:
    val_a,val_result1,val_result2 = sess.run([a,result1,result2])
    print('origin matrix:\n',val_a)
    print('method 1:\n',val_result1)
    print('method 2:\n',val_result2)

origin matrix:
 [[[ 0.6905094   0.13725948 -0.42244634 -0.19795062  0.02895796]
  [-1.2307093  -0.90263253  0.8939539   0.43943858  0.60205126]
  [ 0.1317933   0.7697048  -0.8040689  -0.41206598 -0.66366917]]

 [[-0.07341296 -0.83268213  1.1547179  -1.035854   -0.43292868]
  [ 0.63890094 -1.9335823  -0.61634874 -3.2909455  -1.1862688 ]
  [-1.0031502  -0.07485765  0.53183764  0.55050373 -0.03113765]]

 [[ 0.23482691 -0.9363624   0.30995724 -0.02038437  0.65965956]
  [ 0.73754835  0.23244548 -1.5190666   0.89143264 -0.47610378]
  [ 0.6452583   1.5191171  -0.15525642  0.5060588   1.2310679 ]]

 [[ 0.32281107  0.80718434 -0.865543    0.5899832  -0.66145474]
  [ 0.45294672 -0.31048244 -0.48481905 -1.1497563   1.4231541 ]
  [ 0.2343677  -0.8113462   0.58899856  1.6336825   0.11803629]]

 [[ 0.8602735   1.3486015   1.4897087  -1.2132328  -0.70290196]
  [-2.635646   -0.3950463   0.19890717 -1.9909118   1.3279002 ]
  [-0.88162804 -0.7264523  -0.40416357 -0.7689555   1.33081   ]]]
method 1:
 [[ 0.6905094  -1.2307093   0.1317933 ]
 [-0.83268213 -1.9335823  -0.07485765]
 [ 0.30995724 -1.5190666  -0.15525642]
 [ 0.5899832  -1.1497563   1.6336825 ]
 [-0.70290196  1.3279002   1.33081   ]]
method 2:
 [[ 0.6905094  -1.2307093   0.1317933 ]
 [-0.83268213 -1.9335823  -0.07485765]
 [ 0.30995724 -1.5190666  -0.15525642]
 [ 0.5899832  -1.1497563   1.6336825 ]
 [-0.70290196  1.3279002   1.33081   ]]

推荐阅读