python - 无法从 flow_from_dataframe 进行训练得到了意外没有。类
问题描述
我将在标签位于 csv 文件中的一组图像上训练一个模型。所以我使用flow_from_dataframe from tf.keras
并指定了参数,但class_mode
说到错误并说Found 3662 validated image filenames belonging to 1 classes.
- 稀疏和分类。这是多类分类。”
“最初标签是 int,所以我将它转换为字符串,然后我得到了这个输出。”
df_train=pd.read_csv(r"../input/train.csv",delimiter=',')
df_test=pd.read_csv(r"../input/test.csv",delimiter=',')
print(df_train.head())
print(df_test.head())
df_train['id_code']=df_train['id_code']+'.png'
df_train['diagnosis']=str(df_train['diagnosis'])
df_test['id_code']=df_test['id_code']+'.png'
""" output is
id_code diagnosis
0 000c1434d8d7 2
1 001639a390f0 4
2 0024cdab0c1e 1
3 002c21358ce6 0
4 005b95c28852 0
id_code
0 0005cfc8afb6
1 003f0afdcd15
2 006efc72b638
3 00836aaacf06
4 009245722fa4
"""
train_datagen = ImageDataGenerator(
rescale = 1./255,
rotation_range=30,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
TRAINING_DIR='../input/train_images'
train_generator= train_datagen.flow_from_dataframe(
dataframe=df_train,
directory=TRAINING_DIR,
x_col='id_code',
y_col='diagnosis',
batch_size=20,
target_size=(1050,1050),
class_mode='categorical'#used also sparsed
)
""" output is
Found 3662 validated image filenames belonging to 1 classes.
"""
“我期望输出为"Found 3662 validated image filenames belonging to 5 classes"
,但实际输出为"Found 3662 validated image filenames belonging to 1 classes"
”</p>
解决方案
“稀疏”类模式需要整数值,而“分类”需要类列的一个热编码向量。所以我会尝试:
df['diagnosis'] = df['diagnosis'].astype(str)
然后使用“稀疏”类模式。
train_generator= train_datagen.flow_from_dataframe(
dataframe=df_train,
directory=TRAINING_DIR,
x_col='id_code',
y_col='diagnosis',
batch_size=20,
target_size=(1050,1050),
class_mode='sparse'
)
或者,您可以使用这样的一种热编码:
pd.get_dummies(df,prefix=['diagnosis'], drop_first=True)
然后使用“分类”class_mode:
train_generator= train_datagen.flow_from_dataframe(
dataframe=df_train,
directory=TRAINING_DIR,
x_col='id_code',
y_col=df.columns[1:],
batch_size=20,
target_size=(1050,1050),
class_mode='categorical'
)
推荐阅读
- python - 有没有一种简单的方法可以从 numpy.dtype.descr 中删除“填充”字段?
- linux - 单元测试卡在 Jenkins 构建中
- firebase - Firestore :仅更新键值对
- php - 如何检测内部路径?
- javascript - Fabric.js 对象:修改了几个选定的对象行为
- excel - 循环复制 10 行范围时,我需要在每个新副本中添加一个单元格,使其值增加 1
- configuration - 为什么威胁搜索索引不接收来自 sysmon(windows 索引)的映射数据?
- c - 每秒打印数字
- c# - 反射而不是 .NET 到 .NET 方案
- gradle - gradle - 将 demo.jar 文件部署到服务器