首页 > 解决方案 > tensorflow中无法获取数据集的shape属性值?

问题描述

您可以检查以下代码,

import tensorflow as tf
data = tf.data.Dataset.range(10)
tf.print(data)

输出是

<RangeDataset shapes: (), types: tf.int64>

形状是空的。

标签: tensorflowtensorflow2.0

解决方案


就像 python 一样rangeDataset.range也不返回实际值。相反,它返回一个类似于生成器的对象,称为RangeDataset. 要获得一个 numpy 迭代器,您需要RangeDataset.as_numpy_iterator. 然后,您可以将其转换为列表,就像使用list(range(5))

>>> tf.data.Dataset.range(5)
<RangeDataset shapes: (), types: tf.int64>

>>> list(tf.data.Dataset.range(5).as_numpy_iterator())
[0, 1, 2, 3, 4]

>>> range(5)
range(0, 5)

>>> list(range(5))
[0, 1, 2, 3, 4]

有关其用法的更多示例,您可以查看文档


推荐阅读