How to get the values of a TensorArray that contains tensors of varying shapes

Problem description

I build a TensorArray containing tensors of varying shapes with tf.while_loop(), but I don't know how to get its elements back as a normal Python list of tensors.

For example, I want:

TensorArray([[1,2], [1,2,3], ...]) -> [Tensor([1,2]), Tensor([1,2,3]), ...]

Here is what I tried:

import tensorflow as tf

res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
res = res.write(0, (1, 2))
res = res.write(1, (1, 2, 3))
with tf.Session() as sess:
    print(sess.run(res.stack()))

Running sess.run(res.stack()) raises this error:

TensorArray has inconsistent shapes. Index 0 has shape: [2] but index 1 has shape: [3]

Tags: python, tensorflow

Answer


In general, you cannot turn a TensorArray into a Python list of tensors, because its size is only known when the graph is executed. However, if you know the size in advance, you can build the list of read operations yourself:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    print(res.size()) # Value only known on graph execution
    # Tensor("TensorArraySizeV3:0", shape=(), dtype=int32)
    # Can make a list if the size is known in advance
    tensors = [res.read(i) for i in range(2)]
    print(tensors)
    # [<tf.Tensor 'TensorArrayReadV3:0' shape=<unknown> dtype=int32>, <tf.Tensor 'TensorArrayReadV3_1:0' shape=<unknown> dtype=int32>]
    print(sess.run(tensors))
    # [array([1, 2]), array([1, 2, 3])]
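For comparison, here is a minimal sketch assuming TensorFlow 2.x with eager execution (the default there, unlike the 1.x tf.Session code in this answer). Eagerly, res.size() already holds a concrete value, so the list of tensors can be built without knowing the size in advance. The variable names are only illustrative.

```python
import tensorflow as tf  # assumes TensorFlow 2.x, where eager execution is on by default

res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
res = res.write(0, (1, 2))
res = res.write(1, (1, 2, 3))

# In eager mode size() is a concrete scalar tensor, so int() yields a Python int
n = int(res.size())
# Each read returns an ordinary eager tensor; shapes may differ between elements
tensors = [res.read(i) for i in range(n)]
print([t.numpy().tolist() for t in tensors])  # [[1, 2], [1, 2, 3]]
```

Inside a graph (tf.Session in 1.x, or a tf.function in 2.x) the size is again symbolic, so the limitation described above still applies there.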

Otherwise, you can still iterate over the TensorArray with a while loop. For example, this prints its contents:

# In Python 2, tf.print requires `from __future__ import print_function`
# as the first statement of the file.
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    def loop_body(i, res):
        # The control dependency makes the print run before the counter is incremented
        with tf.control_dependencies([tf.print(res.read(i))]):
            return i + 1, res
    i, res = tf.while_loop(
        lambda i, res: i < res.size(),
        loop_body,
        (tf.constant(0, tf.int32), res))
    print(sess.run(i))
    # [1 2]
    # [1 2 3]
    # 2
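If what you ultimately need is all the values in a single tensor rather than a Python list, a different technique from the one in this answer is TensorArray.concat(), which joins the elements along their first dimension and therefore accepts elements of different lengths. Pairing it with a second TensorArray that records each element's length lets tf.RaggedTensor.from_row_lengths rebuild the rows. This is a sketch assuming TensorFlow 2.x eager execution; the helper to_ragged and the companion lengths array are hypothetical names introduced for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)


def to_ragged(values_ta, lengths_ta):
    """Hypothetical helper: combine variable-length 1-D elements into one RaggedTensor.

    values_ta  -- TensorArray whose elements may differ in their first dimension
    lengths_ta -- companion TensorArray holding the length of each element
    """
    flat = values_ta.concat()         # all elements joined along axis 0
    row_lengths = lengths_ta.stack()  # lengths are scalars, so stacking is fine
    return tf.RaggedTensor.from_row_lengths(flat, row_lengths)


values = tf.TensorArray(tf.int32, size=0, dynamic_size=True, infer_shape=False)
lengths = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
for i, row in enumerate([(1, 2), (1, 2, 3)]):
    t = tf.constant(row, dtype=tf.int32)
    values = values.write(i, t)
    lengths = lengths.write(i, tf.size(t))  # record this element's length

ragged = to_ragged(values, lengths)
print(ragged.to_list())  # [[1, 2], [1, 2, 3]]
```

The design choice here is to keep the lengths alongside the values, because once the elements are concatenated their boundaries are no longer recoverable from the flat tensor alone.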
