首页 > 解决方案 > 用于 TensorFlow 对象检测的平衡数据集

问题描述

我目前想使用 Tensorflows 对象检测 API 来解决我的自定义问题。我已经创建了数据集,但它非常不平衡。数据集有 3 个类,我的主要问题是,一个类有大约 16k 个样本,而另一个类只有大约 2.5k 个样本。

所以我认为我必须平衡数据集。有人告诉我,有一种叫做样本/类权重的东西(不确定这是否 100% 正确),它可以平衡训练样本,因此最大的类对训练的影响小于最小的类。

我找不到这种平衡方法。有人可以给我一个提示从哪里开始吗?

谢谢!

标签: tensorflowdatasetobject-detection

解决方案


你可以做正常的交叉熵,给你一个?x 1 张量,X 损失

如果你想让班级数 N 多计算 T 次,你可以这样做

X = X * tf.reduce_sum(tf.multiply(one_hot_label, class_weight), axis = 1)

tf.multiply

按您想要的任何重量缩放标签,

tf.reduce_sum

将标签向量 a 转换为标量,因此您最终得到 a ? x 1 张量填充了类权重。然后,您只需将损失的张量乘以权重的张量即可获得所需的结果。

由于一类比另一类多 6.4 倍,我将权重 1 和 6.4 分别应用于更常见和不太常见的类。这意味着每次出现较不常见的类时,它的影响是较常见的类的 6.4 倍,所以就像从每个类中看到相同数量的样本一样。

您可能需要对其进行修改,以使权重加起来等于类的数量。这匹配默认情况是所有权重都是 1。在这种情况下,我们有 1 /7.4 和 6.4/7.4


推荐阅读