首页 > 解决方案 > tensorflow concat 与转置使用广播语义

问题描述

v1v2具有相同的形状。在张量流中是否可以连接v1和使用广播语义的转置版本?v2

例如,

v1 = tf.constant([[1,1,1,1],[3,3,3,3],[5,5,5,5]])
v2 = tf.constant([[2,2,2,2],[4,4,4,4]])

我想生产类似的东西

[
 [[[1,1,1,1], [2,2,2,2]],
  [[1,1,1,1], [4,4,4,4]]],
 [[[3,3,3,3], [2,2,2,2]],
  [[3,3,3,3], [4,4,4,4]]],
 [[[5,5,5,5], [2,2,2,2]],
  [[5,5,5,5], [4,4,4,4]]]]

也就是说,使用v1as[3, 4]v2as [2,4],我想做

tf.concat([v1, tf.transpose(v2)], axis=0)

并产生一个 [3,2,2,4]矩阵。

这样做有什么诀窍吗?

标签: tensorflowarray-broadcasting

解决方案


如果你的意思是欺骗一个优雅的解决方案,我不这么认为。但是,一个可行的解决方案是平铺并重复传入的 v1、v2

import tensorflow as tf

v1 = tf.constant([[1, 1, 1, 1],
                  [3, 3, 3, 3],
                  [7, 7, 7, 7],
                  [5, 5, 5, 5]])
v2 = tf.constant([[2, 2, 2, 2],
                  [6, 6, 6, 6],
                  [4, 4, 4, 4]])


def my_concat(v1, v2):
    v1_m, v1_n = v1.shape.as_list()
    v2_m, v2_n = v2.shape.as_list()

    v1 = tf.concat([v1 for i in range(v2_m)], axis=-1)
    v1 = tf.reshape(v1, [v2_m * v1_m, -1])

    v2 = tf.tile(v2, [v1_m, 1])
    v1v2 = tf.concat([v1, v2], axis=-1)

    return tf.reshape(v1v2, [v1_m, v2_m, 2, v2_n])


with tf.Session() as sess:
    ret = sess.run(my_concat(v1, v2))

    print ret.shape
    print ret

推荐阅读