首页 > 解决方案 > 如何在 TensorFlow 中访问数据集的特征字典

问题描述

使用 tensorflow-datasets 我将 MNIST 数据集集成到 Tensorflow 中,现在想用 Matplotlib 可视化单个图像。我是按照本指南做的:https ://www.tensorflow.org/datasets/overview

不幸的是,我在执行过程中收到一条错误消息。但它在指南中效果很好。

根据指南,您必须使用 take() 函数创建一个只有一张图像的新数据集。然后在指南中访问这些功能。在我尝试期间,我总是收到一条错误消息。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf

import tensorflow_datasets as tfds



mnist_train, info = tfds.load(name="mnist", split=tfds.Split.TRAIN, with_info=True)
assert isinstance(mnist_train, tf.data.Dataset)

mnist_example = mnist_train.take(50)

#The error is raised in the next line. 
image = mnist_example["image"]
label = mnist_example["label"]

plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap("gray"))
print("Label: %d" % label.numpy())

这是错误消息:

Traceback (most recent call last):
  File "D:/mnist/model.py", line 24, in <module>
    image = mnist_example["image"]
TypeError: 'DatasetV1Adapter' object is not subscriptable

有谁知道我该如何解决这个问题?经过大量研究,我仍然没有找到解决方案。

标签: python-3.xtensorflow

解决方案


急切的执行

写代码 tf.enable_eager_execution()

为什么?

因为如果您不这样做,您将需要创建图表并session.run()获取一些样本

急切执行定义(参考):

TensorFlow 的 Eager Execution 是一种命令式编程环境,它立即评估 > 操作,无需构建图:操作返回具体值 > 而不是构建计算图以供稍后运行

然后

如何访问 Dataset 对象中的样本

您只需要遍历 DatasetV1Adapter 对象

通过转换为 numpy 访问一些样本的几种方法:

1.

mnist_example = mnist_train.take(50)
for sample in mnist_example:
    image, label = sample["image"].numpy(), sample["label"].numpy()
    plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
    plt.show()
    print("Label: %d" % label)

2.

mnist_example = tfds.as_numpy(mnist_example, graph=None)
for sample in mnist_example:
    image, label = sample["image"], sample["label"]
    plt.imshow(image[:, :, 0].astype(np.uint8), cmap=plt.get_cmap("gray"))
    plt.show()
    print("Label: %d" % label)

注意 1:如果您想要一个 numpy 数组中的所有 50 个样本,您可以创建一个空数组,例如np.zeros((28, 28, 50), dtype=np.uint8)数组,并将这些图像分配给它的元素。

注意2:为了imshow的目的,不要转换成np.float32,它没用,图像是uint8格式/范围(默认不归一化)


推荐阅读