machine-learning - 在 Keras 中为 TF 数据集中的 one-hot 编码标签指定类或样本权重
问题描述
我正在尝试在不平衡的训练集上训练图像分类器。为了应对类不平衡,我想对类或单个样本进行加权。加权类似乎不起作用。不知何故,对于我的设置,我无法找到指定样本权重的方法。您可以在下面阅读我如何加载和编码训练数据以及我尝试的两种方法。
训练数据加载和编码
我的训练数据存储在一个目录结构中,其中每个图像都放置在与其类对应的子文件夹中(我总共有 32 个类)。由于训练数据太大,一次全部加载到内存中,我使用 image_dataset_from_directory 并通过它描述TF Dataset中的数据:
train_ds = keras.preprocessing.image_dataset_from_directory (training_data_dir,
batch_size=batch_size,
image_size=img_size,
label_mode='categorical')
我使用 label_mode 'categorical',以便将标签描述为one-hot 编码向量。
然后我预取数据:
train_ds = train_ds.prefetch(buffer_size=buffer_size)
方法 1:指定类权重
在这种方法中,我尝试通过 fit 的 class_weight 参数指定类的类权重:
model.fit(
train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
class_weight=class_weights
)
对于每个类,我们计算与该类的训练样本数成反比的权重。这是按如下方式完成的(这是在上述 train_ds.prefetch() 调用之前完成的):
class_num_training_samples = {}
for f in train_ds.file_paths:
class_name = f.split('/')[-2]
if class_name in class_num_training_samples:
class_num_training_samples[class_name] += 1
else:
class_num_training_samples[class_name] = 1
max_class_samples = max(class_num_training_samples.values())
class_weights = {}
for i in range(0, len(train_ds.class_names)):
class_weights[i] = max_class_samples/class_num_training_samples[train_ds.class_names[i]]
我不确定这个解决方案是否有效,因为 keras 文档没有指定 class_weights 字典的键,以防标签是一次性编码的。我尝试以这种方式训练网络,但发现权重对生成的网络没有真正的影响:当我查看每个单独类的预测类分布时,我可以识别整个训练集的分布,其中对于每个类别,最有可能预测占主导地位的类别。在没有指定任何类别权重的情况下运行相同的训练会导致相似的结果。所以我怀疑权重似乎对我的情况没有影响。
这是因为指定类权重不适用于一次性编码标签,还是因为我可能做错了其他事情(在我没有在这里显示的代码中)?
方法2:指定样本权重
作为提出不同(在我看来不太优雅)解决方案的尝试,我想通过 fit 方法的 sample_weight 参数指定单个样本权重。但是从文档中我发现:
[...] 当 x 是数据集、生成器或 keras.utils.Sequence 实例时,不支持此参数,而是提供 sample_weights 作为 x 的第三个元素。
在我的设置中确实是这种情况,其中 train_ds 是一个数据集。现在我真的很难找到可以从中得出如何修改 train_ds 的文档,这样它就有了带有权重的第三个元素。我认为使用数据集的 map 方法可能很有用,但我想出的解决方案显然无效:
train_ds = train_ds.map(lambda img, label: (img, label, class_weights[np.argmax(label)]))
有没有人有一个可以与加载的数据集结合使用的解决方案image_dataset_from_directory
?
解决方案
推荐阅读
- sockets - 这段代码中的权重在哪里更新?
- r - 使用非标准评估按多列排序
- html - 灯箱的奇怪行为
- json - 如何修复convertapi中的“状态=内部服务器错误,状态代码=500”错误,以便在顶点中合并pdf?
- kubernetes - HPA labelSelector 不过滤外部指标
- php - Htaccess 将查询字符串作为目录解释为脚本
- java - 在android studio中使用线程崩溃Java apk
- python - 在第一个数字后使用条件拆分列
- php - 一旦我的表单位于具有实际服务器(Heroku)的托管站点上,React 会通过 PHP 提交我的表单,还是我需要使用不同的代码?
- webpack - 如何使用 Webpack 4 加载图像?