首页 > 解决方案 > tensorflow 嵌套 map_fn 连接两个张量

问题描述

说我有两个张量:

a=Tensor("zeros_3:0", shape=(2, 4, 5), dtype=float32)
b=Tensor("ones_3:0", shape=(2, 3, 5), dtype=float32)

如何(2,3,4,10)使用嵌套的 map_fn 或其他 tf 函数沿轴 2 连接每个元素以获得新的张量形状?

这是我的 for 循环版本

        concat_list = []
        for i in range(a.get_shape()[1]):
            for j in range(b.get_shape()[1]):
                concat_list.append(tf.concat([a[:, i, :], b[:, j, :]], axis=1))

使用“新单位维度”有一个类似的问题,但我不知道如何使用tf.concat“新单位维度”。

标签: tensorflownested

解决方案


您可以使用tf.tile和。一个例子:tf.expand_dimstf.concat

import tensorflow as tf

a = tf.random_normal(shape=(2,4,5),dtype=tf.float32)
b = tf.random_normal(shape=(2,3,5),dtype=tf.float32)

# your code
concat_list = []
for i in range(a.get_shape()[1]):
    for j in range(b.get_shape()[1]):
        concat_list.append(tf.concat([a[:, i, :], b[:, j, :]], axis=1))

# Application  method
A = tf.tile(tf.expand_dims(a,axis=1),[1,b.shape[1],1,1])
B = tf.tile(tf.expand_dims(b,axis=2),[1,1,a.shape[1],1])
result = tf.concat([A,B],axis=-1)

with tf.Session() as sess:
    concat_list_val,result_val = sess.run([concat_list,result])
    print(concat_list_val[-1])
    print(result_val.shape)
    print(result_val[:,-1,-1,:])

# your result
[[ 1.0459949   1.5562199  -0.04387079  0.17898582 -1.9795663   0.988437
  -0.40415847  0.8865694  -1.4764767  -0.8417388 ]
 [-0.3542176  -0.3281141   0.01491702  0.91899025 -1.0651684   0.12315683
   0.6555444  -0.80451876 -1.3260773   0.33680603]]
# Application result shape
(2, 3, 4, 10)
# Application result 
[[ 1.0459949   1.5562199  -0.04387079  0.17898582 -1.9795663   0.988437
  -0.40415847  0.8865694  -1.4764767  -0.8417388 ]
 [-0.3542176  -0.3281141   0.01491702  0.91899025 -1.0651684   0.12315683
   0.6555444  -0.80451876 -1.3260773   0.33680603]]

表现

您可以使用以下代码来比较速度。

import datetime
...

with tf.Session() as sess:
    start = datetime.datetime.now()
    print('#' * 60)
    for i in range(10000):
        result_val = sess.run(result)
    end = datetime.datetime.now()
    print('cost time(seconds) : %.2f' % ((end - start).total_seconds()))

    start = datetime.datetime.now()
    print('#' * 60)
    for i in range(10000):
        concat_list_val = sess.run(concat_list)
    end = datetime.datetime.now()
    print('cost time(seconds) : %.2f' % ((end - start).total_seconds()))

向量化方法 10000 次迭代和1.48s循环 10000 次迭代5.76s在我的 8GB GPU 内存上进行。但是矢量化方法需要并且循环时间是when和。a.shape=(2,4,5)b.shape=(2,3,5)3.28s317.23sa.shape=(20,40,5)b.shape=(20,40,5)

矢量化方法将明显快于tf.map_fn()和 python 循环。


推荐阅读