python - 如何在张量流中进行花式索引
问题描述
我有一个具有(假设)形状(5、3、5)的 Tensorflow 张量 A。我想得到一个形状为 (5, 3) 的张量 B 使得
# B = [A[0, :, 0], A[1, :, 1], A[2, :, 2], ...]
我想在不使用任何 for 循环的情况下实现此索引。使用 numpy one 可以:
import numpy as np
# A.shape = (5, 3, 5)
B = A[np.arange(A.shape[0]), :, np.arange(A.shape[2])]
任何建议如何使用 Tensorflow 做到这一点?
解决方案
有两种方法可以实现您的目标。
import tensorflow as tf
a = tf.random_normal(shape=(5,3,5))
# method 1: take the diagonal after transpose
b_diag = tf.matrix_diag_part(tf.transpose(a,[1,0,2])) # shape = (3,5)
result1 = tf.transpose(b_diag,[1,0])
# method 2: take the value by indices
indices = tf.stack([tf.range(tf.shape(a)[0])]*2,axis=-1)
# [[0 0]
# [1 1]
# [2 2]
# [3 3]
# [4 4]]
result2 = tf.gather_nd(tf.transpose(a,[0,2,1]),indices)
with tf.Session() as sess:
val_a,val_result1,val_result2 = sess.run([a,result1,result2])
print('origin matrix:\n',val_a)
print('method 1:\n',val_result1)
print('method 2:\n',val_result2)
origin matrix:
[[[ 0.6905094 0.13725948 -0.42244634 -0.19795062 0.02895796]
[-1.2307093 -0.90263253 0.8939539 0.43943858 0.60205126]
[ 0.1317933 0.7697048 -0.8040689 -0.41206598 -0.66366917]]
[[-0.07341296 -0.83268213 1.1547179 -1.035854 -0.43292868]
[ 0.63890094 -1.9335823 -0.61634874 -3.2909455 -1.1862688 ]
[-1.0031502 -0.07485765 0.53183764 0.55050373 -0.03113765]]
[[ 0.23482691 -0.9363624 0.30995724 -0.02038437 0.65965956]
[ 0.73754835 0.23244548 -1.5190666 0.89143264 -0.47610378]
[ 0.6452583 1.5191171 -0.15525642 0.5060588 1.2310679 ]]
[[ 0.32281107 0.80718434 -0.865543 0.5899832 -0.66145474]
[ 0.45294672 -0.31048244 -0.48481905 -1.1497563 1.4231541 ]
[ 0.2343677 -0.8113462 0.58899856 1.6336825 0.11803629]]
[[ 0.8602735 1.3486015 1.4897087 -1.2132328 -0.70290196]
[-2.635646 -0.3950463 0.19890717 -1.9909118 1.3279002 ]
[-0.88162804 -0.7264523 -0.40416357 -0.7689555 1.33081 ]]]
method 1:
[[ 0.6905094 -1.2307093 0.1317933 ]
[-0.83268213 -1.9335823 -0.07485765]
[ 0.30995724 -1.5190666 -0.15525642]
[ 0.5899832 -1.1497563 1.6336825 ]
[-0.70290196 1.3279002 1.33081 ]]
method 2:
[[ 0.6905094 -1.2307093 0.1317933 ]
[-0.83268213 -1.9335823 -0.07485765]
[ 0.30995724 -1.5190666 -0.15525642]
[ 0.5899832 -1.1497563 1.6336825 ]
[-0.70290196 1.3279002 1.33081 ]]
推荐阅读
- scala - 如何使用scala在镶木地板文件中写入常量值?
- wordpress - 如何根据子域的 ID 显示 WordPress 帖子
- sql - 如何在同一查询中使用 rowNumber 获取 MAX rownumber
- javascript - 如何在递归方法中进行同步调用?
- vb.net - 删除 vb.net 中的时间/日期/小时
- swift - 有没有办法将泛型限制为 Swift 中的一种或另一种类型?
- typescript - 打字稿字符串到枚举转换错误
- python - Docker 正在构建但未在浏览器中显示数据
- bash - 从同一个 playbook 中获取 ansible playbook 的 PID
- if-statement - 乘法 IF 有问题