首页 > 解决方案 > TypeError: 'int' 对象在使用 model.fit() 时不可迭代

问题描述

我试图拟合我的模型,但遇到“类型错误”。这段代码的编写与 CNN 的 kaggle 源代码示例中描述的几乎相同。

https://www.kaggle.com/kanncaa1/convolutional-neural-network-cnn-tutorial/notebook

但是,不断出现类型错误。我改成model.fit_generatormodel. fit因为函数是从新版本的tensorflow更新而来的。我猜想shape[0]可能会导致这个与内部相关的问题。谁能帮忙指出下面代码中的错误类型?

# model fitting
# from tensorflow 2.1.0
history = model.fit(datagen.flow(train_x, train_y, batch_size = batch_size),
                          epochs = epochs,
                          validation_data = (valid_x, valid_y),
                          steps_per_epoch = train_x.shape[0] // batch_size)

标签: pythontensorflowdeep-learningconv-neural-network

解决方案


可能的解决方案:

运行print(x_train.shape),并确保除第一个维度之外的所有值都与input_shape您在模型定义期间指定的值匹配。如果您从 kaggle 内核移植代码,那么我假设您的模型定义为:

model.add(Conv2D(filters = 8, kernel_size = (5,5),padding = 'Same', 
                 activation ='relu', input_shape = (28,28,1)))

确保x_train.shape符合规定input_shape


推荐阅读