Finding thresholds for binary classification

Problem description

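I have a sorted array of numeric values and a corresponding array of classes, given as yes/no. I need to find the thresholds that satisfy this criterion, quoting the paper I am studying: "T is a threshold if it falls between two consecutive examples that do not belong to the same class. In the special case where a group of two or more examples has the same value but belongs to more than one class, the cut points on either side of the group are also thresholds. Examples with the same value cannot be separated."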

If I understand correctly, given:

vals = [10,12, 22, 28, 28, 40, 41]
classes = ['y','y','n','y','n','y','n']

the thresholds must be: [17, 25, 34, 40.5]

This is the code I wrote:

    for i in range(len(vals)-1):
        if vals[i] != vals[i+1]:
            if classes[i] != classes[i+1]:
                thresholds.append((vals[i] + vals[i+1]) / 2)
        else:
            j = i
            while vals[i] == vals[i+1]:
                i = i+1
            if j != 0:
                thresholds.append((vals[j] + vals[j-1]) / 2)
            thresholds.append((vals[i] + vals[i+1]) / 2)

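But 1) I really don't like it and would like it to be more compact, and 2) although it works for the example above, it is not always correct. For example, if I have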

vals = [2,2,5,5,7,11,18] 
out = ['y','y','y','y','n','n','n']

I would expect the only threshold to be [6], but this code also prints 3.5.

How can I make this cleaner and more general?

Tags: python

Solution


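Update:
Here is the new code so far (it could be refactored further; I will post that in my next edit). I have tested it on a large number of test cases. Only the code is shown below; you can find the detailed code, with comments and unit tests, at these links: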

Code only:

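from itertools import groupby

# Marker for a value shared by examples of more than one class (the paper's
# special case). Its definition is not shown in the post; a unique sentinel
# object is assumed here.
case2_label = object()

def compress_group(vls):
    # vls is one (key, group) pair from groupby; group yields (value, label).
    val0, lab0 = next(vls[1])
    if all(lab == lab0 for val, lab in vls[1]):
        return val0, lab0
    return val0, case2_label

# The enclosing function is not shown in the post; this name and signature are assumed.
def find_thresholds(values, labels):
    # values must be sorted: groupby only merges adjacent equal values.
    vl_combined = [(v, l) for v, l in zip(values, labels)]
    vl_groups = groupby(vl_combined, lambda vc: vc[0])
    vl_groups = map(compress_group, vl_groups)

    thresholds = []
    prev_value, prev_label = next(vl_groups)
    for curr_value, curr_label in vl_groups:
        # Cut here if the previous value was mixed or the class changes.
        if prev_label is case2_label or curr_label != prev_label:
            threshold = (prev_value + curr_value) / 2
            thresholds.append(threshold)
        prev_value, prev_label = curr_value, curr_label

    return thresholds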
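
The rule quoted in the question can also be checked against both example inputs with a more compact sketch. This is only one possible formulation, not the code from the links above: the function name `thresholds_by_group` is made up for illustration, and it assumes each distinct value can be summarised by the set of labels seen at that value.

```python
from itertools import groupby
from operator import itemgetter


def thresholds_by_group(values, labels):
    """Midpoints between adjacent distinct values that the quoted rule treats as thresholds."""
    # Sort by value so that equal values are adjacent (groupby needs this).
    pairs = sorted(zip(values, labels), key=itemgetter(0))
    # One entry per distinct value: (value, set of labels seen at that value).
    groups = [(v, {lab for _, lab in grp})
              for v, grp in groupby(pairs, key=itemgetter(0))]
    cuts = []
    for (v1, labs1), (v2, labs2) in zip(groups, groups[1:]):
        # Cut if the label sets differ (class change, or one side is mixed),
        # or if both sides are the same mixed set.
        if labs1 != labs2 or len(labs1) > 1:
            cuts.append((v1 + v2) / 2)
    return cuts


print(thresholds_by_group([10, 12, 22, 28, 28, 40, 41],
                          ['y', 'y', 'n', 'y', 'n', 'y', 'n']))
# [17.0, 25.0, 34.0, 40.5]
print(thresholds_by_group([2, 2, 5, 5, 7, 11, 18],
                          ['y', 'y', 'y', 'y', 'n', 'n', 'n']))
# [6.0]
```

Using a set per value avoids a separate "mixed" marker: a set with more than one label is the special case from the paper, and comparing sets covers the ordinary class change.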
