python - 如何从张量流中的数据集类中获取 10K MNIST 图像的子集?
问题描述
我找到了以下在 tensorflow 中获取 mnist 数据集的方法:
def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):
def _input_fn():
images_batch, labels_batch = tf.train.shuffle_batch(
tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
batch_size=batch_size,
capacity=capacity,
min_after_dequeue=min_after_dequeue,
enqueue_many=True,
num_threads=4)
features_map = {'images': images_batch}
return features_map, labels_batch
return _input_fn
data = tf.contrib.learn.datasets.mnist.load_mnist()
train_input_fn = get_input_fn(data.train, batch_size=256)
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
数据变量是数据集对象。这种方法对我来说很不清楚,我无法弄清楚如何将 60K 数据集转换为 10K 数据集。
当我执行以下操作时:
data = tf.contrib.learn.datasets.mnist.load_mnist().take(10000)
我得到错误:
AttributeError: 'Datasets' object has no attribute 'take'
谢谢你的帮助!
解决方案
contrib 模块中的此功能已弃用。您可以使用tf.keras.datasets.mnist.load_data()
. 根据https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data,它返回
Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
因此,为了对其应用任何功能,您需要将其加载到数据集对象中。
train, test = tf.keras.datasets.mnist.load_data(path='mnist.npz')
dataset_train = tf.data.Dataset.from_tensor_slices((train[0], train[1]))
dataset_test = tf.data.Dataset.from_tensor_slices((test[0], test[1]))
然后,您可以将 shuffle、batch、take 或任何映射函数dataset_train
应用于dataset_test
对象
推荐阅读
- java - spring.datasource.initialize 已弃用
- python - 在 Kivy 中实现计时器时遇到问题
- python - 在 Pycharm 中无法添加 conda 环境
- java - @PostConstruct 注解的方法不会在 springboot 应用程序启动时启动
- json - 将 Json 嵌套到 CSV PowerShell
- android - 需要从 Aadhar 扫描数据中修剪 xml 内容
- node.js - 分配具有唯一 ID 的 WebSocket 和 net.Socket
- c# - 使用 LINQ 按键收集选择数据的最快方法是什么?
- javascript - 如何提高 fetch api JSON 下载速度或加载图标?
- amazon-web-services - 如果 EKS 工作程序节点位于私有子网中,则 ELB 服务中断