首页 > 解决方案 > 如何通过沿某个轴交错来合并两个 3D 张量?

问题描述

假设我想通过沿一个特定轴交错来合并两个 3D 张量流张量 a 和 b。例如,张量 a 的形状为 (3,3,2),张量 b 的形状为 (3,2,2)。我想创建一个沿轴 1 交错的张量 c,从而产生形状为 (3,5,2) 的张量。

例子:

a = [[[1,1],[2,2],[3,3]],
     [[4,4],[5,5],[6,6]],
     [[7,7],[8,8],[9,9]]]

b = [[[10,10],[11,11]],
     [[12,12],[13,13]],
     [[14,14],[15,15]]]

c = [[[1,1],[10,10],[2,2],[11,11],[3,3]],
     [[4,4],[12,12],[5,5],[13,13],[6,6]],
     [[7,7],[14,14],[8,8],[15,15],[9,9]]]

标签: pythontensorflow

解决方案


您可以先重新排序列的索引。

import tensorflow as tf

a = [[[1,1],[2,2],[3,3]],
     [[4,4],[5,5],[6,6]],
     [[7,7],[8,8],[9,9]]]

b = [[[10,10],[11,11]],
     [[12,12],[13,13]],
     [[14,14],[15,15]]]

a_tf = tf.constant(a)
b_tf = tf.constant(b)

a_tf_column = tf.range(a_tf.shape[1])*2
# [0 2 4]
b_tf_column = tf.range(b_tf.shape[1])*2+1
# [1 3]

column_indices = tf.concat([a_tf_column,b_tf_column],axis=-1)
# Before TF v1.13
column_indices = tf.contrib.framework.argsort(column_indices)
## From TF v1.13
# column_indices = tf.argsort(column_indices)

# [0 3 1 4 2]

然后你应该为tf.gather_nd().

column,row = tf.meshgrid(column_indices,tf.range(a_tf.shape[0]))
combine_indices = tf.stack([row,column],axis=-1)
# [[[0,0],[0,3],[0,1],[0,4],[0,2]],
#  [[1,0],[1,3],[1,1],[1,4],[1,2]],
#  [[2,0],[2,3],[2,1],[2,4],[2,2]]]

最后,您应该连接 and 的值ab使用tf.gather_nd()来获得结果。

combine_value = tf.concat([a_tf,b_tf],axis=1)
result = tf.gather_nd(combine_value,combine_indices)

with tf.Session() as sess:
    print(sess.run(result))

# [[[1,1],[10,10],[2,2],[11,11],[3,3]],
#  [[4,4],[12,12],[5,5],[13,13],[6,6]],
#  [[7,7],[14,14],[8,8],[15,15],[9,9]]]

推荐阅读