python - How to get the values from a TensorArray containing tensors of varying shapes
Problem description
I get a TensorArray containing tensors of different shapes from tf.while_loop(), but I don't know how to turn it into an ordinary list of tensors.
For example:
TensorArray([[1,2], [1,2,3], ...]) -> [Tensor([1,2]), Tensor([1,2,3]), ...]
A minimal reproduction:
```python
import tensorflow as tf

res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
res = res.write(0, (1, 2))
res = res.write(1, (1, 2, 3))
with tf.Session() as sess:
    print(sess.run(res.stack()))
```
Running sess.run(res.stack()) fails with this error:
TensorArray has inconsistent shapes. Index 0 has shape: [2] but index 1 has shape: [3]
Solution
In general, you cannot build a Python list of the tensors in a TensorArray, because its size is only known when the graph is executed. However, if you know the size in advance, you can build the list of read operations yourself:
```python
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))
    print(res.size())  # Value only known on graph execution
    # Tensor("TensorArraySizeV3:0", shape=(), dtype=int32)
    # Can make a list if the size is known in advance
    tensors = [res.read(i) for i in range(2)]
    print(tensors)
    # [<tf.Tensor 'TensorArrayReadV3:0' shape=<unknown> dtype=int32>, <tf.Tensor 'TensorArrayReadV3_1:0' shape=<unknown> dtype=int32>]
    print(sess.run(tensors))
    # [array([1, 2]), array([1, 2, 3])]
```
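On TensorFlow 2.x, where eager execution is the default, the size restriction above mostly goes away: size() returns a concrete value, so it can drive an ordinary Python loop over read(). A minimal sketch, assuming TensorFlow 2.x running eagerly (it does not apply inside tf.function or a TF1 Session, where the size is still symbolic):

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution (the default)

ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
ta = ta.write(0, (1, 2))
ta = ta.write(1, (1, 2, 3))

# Eagerly, ta.size() is a concrete scalar tensor, so int() turns it into a loop bound.
tensors = [ta.read(i) for i in range(int(ta.size()))]
print([t.numpy().tolist() for t in tensors])  # [[1, 2], [1, 2, 3]]
```

Each element of tensors is an ordinary tf.Tensor with its own shape, which is the list form the question asks for.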
Otherwise, you can still use a while loop to iterate over the TensorArray. For example, this prints its contents:
```python
# Python 2 needs this import, at the very top of the file, for tf.print to parse
from __future__ import print_function

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True, infer_shape=False)
    res = res.write(0, (1, 2))
    res = res.write(1, (1, 2, 3))

    def loop_body(i, res):
        # Ops created inside this block (here i + 1) only run after the print op
        with tf.control_dependencies([tf.print(res.read(i))]):
            return i + 1, res

    i, res = tf.while_loop(
        lambda i, res: i < res.size(),
        loop_body,
        (tf.constant(0, tf.int32), res))
    print(sess.run(i))
    # [1 2]
    # [1 2 3]
    # 2
```
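If what you need afterwards is a single value holding all the rows rather than a Python list, another option is to avoid stack(), which requires every element to have the same shape. TensorArray.concat() only requires the elements to match in every dimension except the first, so recording each row's length in a second TensorArray lets you split the concatenated values back into rows with tf.RaggedTensor.from_row_lengths. A sketch of that idea, assuming TensorFlow 2.x; the function name ragged_from_loop and the tf.range rows are hypothetical stand-ins for your own loop body:

```python
import tensorflow as tf  # assumes TensorFlow 2.x


@tf.function
def ragged_from_loop(n):
    # Hypothetical loop: step i produces tf.range(1, i + 3), i.e. [1, 2], [1, 2, 3], ...
    values = tf.TensorArray(tf.int32, size=0, dynamic_size=True, infer_shape=False)
    lengths = tf.TensorArray(tf.int32, size=0, dynamic_size=True)

    def body(i, values, lengths):
        row = tf.range(1, i + 3)
        # Record each row and its length at the same index.
        return i + 1, values.write(i, row), lengths.write(i, tf.size(row))

    _, values, lengths = tf.while_loop(
        lambda i, values, lengths: i < n,
        body,
        (tf.constant(0), values, lengths))

    # concat() joins the rows end to end; the recorded lengths split them again.
    return tf.RaggedTensor.from_row_lengths(values.concat(), lengths.stack())


rt = ragged_from_loop(tf.constant(2))
print(rt.to_list())  # [[1, 2], [1, 2, 3]]
```

The result stays a single graph value, so it can be returned from tf.function or passed on without knowing the number of rows in advance.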