首页 > 解决方案 > 使用 tf.map_fn() 迭代张量时,Tensorflow 获取维度信息

问题描述

假设我有一个tsshape张量[s1, s2, s3],我想用tf.map_fn如下方式迭代它:

tf.map_fn(lambda dim1:
    tf.map_fn(lambda dim2:
        do_sth(dim, idx1, idx2)
    ,dim)
,ts)

idx1以上idx2是维度0的索引和维度1的索引。tsdo_sth()怎样才能得到这些?我想得到它,就好像我在做类似的事情:

for idx1 in range(s1):
    for idx2 in range(s2):
        tensor = ts[idx1][idx2]
        do_sth(tensor, idx1, idx2)

我不能那样做的原因是大部分时间s1, s2, s3都是未知的(即ts形状(?, ?, t3)或相似)

那可能吗?

标签: pythontensorflow

解决方案


我建议您在最后添加一个额外的维度并在运行命令之前用索引填充它。

此代码(已测试)将索引添加到 X 和 Y 索引分别乘以 10 和 100 的值:

from __future__ import print_function
import tensorflow as tf
import numpy as np

r = 3

a = tf.reshape( tf.constant( range( r * r ) ), ( r, r ) )
x = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ None, : ], tf.int32 ), [ r, 1 ] )
y = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ :, None ], tf.int32 ), [ 1, r ] )

b = tf.stack( [ a, x, y ], axis = -1 )

c = tf.map_fn( lambda y: tf.map_fn( lambda x: 
        x[ 0 ] + 10 * x[ 1 ] + 100 * x[ 2 ]
    , y ), b )

with tf.Session() as sess:
    res = sess.run( [ c ] )
    for x in res:
        print()
        print( x )

输出:

[[ 0 11 22]
[103 114 125]
[206 217 228]]


推荐阅读