首页 > 解决方案 > 如何从张量流中的数据集类中获取 10K MNIST 图像的子集?

问题描述

我找到了以下在 tensorflow 中获取 mnist 数据集的方法:

def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):

  def _input_fn():
    images_batch, labels_batch = tf.train.shuffle_batch(
        tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    features_map = {'images': images_batch}
    return features_map, labels_batch

  return _input_fn

    data = tf.contrib.learn.datasets.mnist.load_mnist()

    train_input_fn = get_input_fn(data.train, batch_size=256)
    eval_input_fn = get_input_fn(data.validation, batch_size=5000)

数据变量是数据集对象。这种方法对我来说很不清楚,我无法弄清楚如何将 60K 数据集转换为 10K 数据集。

当我执行以下操作时:

data = tf.contrib.learn.datasets.mnist.load_mnist().take(10000)

我得到错误:

AttributeError: 'Datasets' object has no attribute 'take'

但是文档提供了这种方法: 在此处输入图像描述

谢谢你的帮助!

标签: pythontensorflow

解决方案


contrib 模块中的此功能已弃用。您可以使用tf.keras.datasets.mnist.load_data(). 根据https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data,它返回

Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. 

因此,为了对其应用任何功能,您需要将其加载到数据集对象中。

train, test = tf.keras.datasets.mnist.load_data(path='mnist.npz')
dataset_train = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
dataset_test = tf.data.Dataset.from_tensor_slices((test[0], test[1]))

然后,您可以将 shuffle、batch、take 或任何映射函数dataset_train应用于dataset_test对象


推荐阅读