首页 > 解决方案 > ImageDataGenerator .next() 这可以等同于 mnist.train.next_batch 吗?

问题描述

大家好,我是 tensorflow 的新手,我正在研究胶囊网络,我的数据集如下所示:

 -train
     |-class1
     |-class2
     |-class3

我正在使用 ImageDataGenerator flow_from_directory 从目录名称生成标签:

 train = ImageDataGenerator(rescale= 1/255)
trainData = train.flow_from_directory(dir_path , class_mode='categorical' , target_size=(80, 80), batch_size=batch_size )

这是标签的占位符:

 y = tf.placeholder(shape=[None], dtype=tf.int64, name="y")

我正在关注的教程是在 mnist 数据集上工作,所以他这样做是为了将训练数据提供给模型:

X_batch, y_batch = mnist.train.next_batch(batch_size)
        

我这样做了:

 X_batch, y_batch = trainData.next()

它适用于图像,但不适用于 labesl,因为我收到了这个错误:

Cannot feed value of shape (10, 3) for Tensor 'y:0', which has shape '(?,)'

y_batch.shape --> (10, 3) 10 是 batch_size 我猜 3 是类的数量,我对此感到困惑

任何人都可以帮忙吗?

标签: pythontensorflowdeep-learning

解决方案


占位符创建空然后加载数据。对于标签不需要创建占位符


推荐阅读