首页 > 解决方案 > Tensorflow 数据集如何获取数据生成器的形状?

问题描述

考虑从 tensorflow 数据集中加载以下数据集

(ds_train, ds_test), ds_info= tfds.load('mnist', split=['train', 'test'],
                                        shuffle_files=True,
                                        as_supervised=True,with_info=True)

不过,该网站称

#https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator
#Warning: SOME ARGUMENTS ARE DEPRECATED: (output_shapes, output_types). They will be removed in a future version. 
#Instructions for updating: Use output_signature instead

但没有一个

ds_train.output_shapes
ds_train.output_types
ds_train.output_signature

正在工作

这里提到了一个类似的问题 # https://github.com/tensorflow/datasets/issues/102,所以现在只有临时修复

shape_of_data=tf.compat.v1.data.get_output_shapes(ds_train)

正在工作,它返回

(TensorShape([None, 28, 28, 1]), TensorShape([None]))

另一个更新的函数正在工作,但无法将 TensorShape 从参数中取出

tf.data.DatasetSpec(ds_train) 

回来

DatasetSpec(<_OptionsDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>, TensorShape([]))

无法分配。

获得生成器/迭代器形状的更新函数或属性是什么?

标签: pythontensorflowtensorflow-datasets

解决方案


您的变量ds_info具有以下信息:

height, width, channels = ds_info.features['image'].shape

像这样看:

ds_info.features['image']
Image(shape=(28, 28, 1), dtype=tf.uint8)

推荐阅读