首页 > 技术文章 > 4、K近邻(k-Nearest Neighbor,KNN)算法——监督、分类

ai-learning-blogs 2019-08-15 11:37 原文

1、K近邻(k-Nearest Neighbor,KNN)算法

K近邻(k-Nearest Neighbor,KNN)分类算法,是一种基本分类算法,于1968年由Cover和Hart提出。该算法的主体思想是根据距离相近的邻居类别,来判定自己的所属类别。

2、K近邻算法原理

该算法的思路是:

1)计算测试对象与训练集中所有对象的距离,可以是欧式距离、余弦距离等,比较常用的是较为简单的欧式距离;
2)找出上步计算的距离中最近的K个对象,作为测试对象的邻居;
3)找出K个对象中出现频率最高的对象,其所属的类别就是该测试对象所属的类别。

k近邻算法的三个基本要素:k值的选择、距离度量及分类决策规则。

算法优缺点:

1)优点:思想简单,易于理解,易于实现,无需估计参数,无需训练;适合对稀有事物进行分类;特别适合于多分类问题。
2)缺点:懒惰算法,进行分类时计算量大,要扫描全部训练样本计算距离,内存开销大,评分慢;当样本不平衡时,如其中一个类别的样本较大,可能会导致对新样本计算近邻时,大容量样本占大多数,影响分类效果;可解释性较差,无法给出决策树那样的规则。

注意问题:

1)K值的设定:K值设置过小会降低分类精度;若设置过大,且测试样本属于训练集中包含数据较少的类,则会增加噪声,降低分类效果。通常,K值的设定采用交叉检验的方式(以K=1为基准)。经验规则:K一般低于训练样本数的平方根。

2)优化问题:压缩训练样本;确定最终的类别时,不是简单的采用投票法,而是进行加权投票,距离越近权重越高。

3、K近邻算法的实现:kd树

实现k近邻算法,主要考虑的问题是如何对训练数据进行快速k近邻搜索。这点在特征空间的维数大及训练数据容量大时尤为重要。为了提高k近邻搜索的效率,可以考虑使用特殊的结构存储训练数据,以减小计算距离的次数。

(1)构造kd树

kd树是一种分割K维数据空间的数据结构,属于二叉树。kd树的每个结点对应于一个k维超矩阵区域。

平衡kd树的构造:

输入:k维空间数据集$T = \left\{ {{x^{\left( 1 \right)}}{\rm{,}}{x^{\left( 2 \right)}}{\rm{,}} \cdots {\rm{,}}{x^{\left( N \right)}}} \right\}$

其中${x^{\left( i \right)}} = \left( {x_1^{\left( i \right)}{\rm{,}}x_2^{\left( i \right)}{\rm{,}} \cdots {\rm{,}}x_k^{\left( i \right)}} \right)$,$i = 1{\rm{,}}2{\rm{,}} \cdots {\rm{,N}}$,N为训练样本数

输出:kd树

1)开始:构造根节点,根节点对应于包含T的k维空间的超矩阵区域。

选择第一维度${x_1}$为坐标轴,以T中所有实例的${x_1}$坐标的中位数为切分点,将根节点对应的超矩阵区域切分为两个子区域。切分由通过切分点并且与坐标轴${x_1}$垂直的超平面实现。

由根节点生成深度为1的左、右子结点:左子结点对应坐标${x_1}$小于切分点的子区域,右子结点对应于坐标${x_1}$大于切分点的子区域。

将落在切分超平面上的实例点保存在根结点。

2)重复:对深度为$j$的结点,选择${x_l}$为切分点的坐标轴,$l = j\left( {{\rm{mod k}}} \right) + 1$,以该结点的区域中所有实例的${x_l}$坐标的中位点为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并且与坐标轴${x_l}$垂直的超平面实现。

由根节点生成深度为$j+1$的左、右子结点:左子结点对应坐标${x_l}$小于切分点的子区域,右子结点对应于坐标${x_l}$大于切分点的子区域。

将落在切分超平面上的实例点保存在该结点。

3)直到两个子区域没有实例存在时停止,从而形成kd树的区域划分。

(2)搜索kd树(用kd树的最近邻搜索)

输入: 已构造的kd树;目标点x;

输出:x的最近邻。

1)在kd树中找出包含目标点x的叶结点:从根结点出发,递归地向下访问kd树。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止。

2)以此叶结点为“当前最近点”。

3)递归地向上回退,在每个结点进行以下操作:

a)如果该结点保存的实例点比当前最近点距离目标点更近,则以该实例点为“当前最近点“。

b)当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一子结点对应的区域是否有更近的点。具体地,检查另一子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。

如果相交,可能在另一个子结点对应的区域内存在距目标点更近的点,移动到另一个子结点。接着,递归地进行最近邻搜索。

如果不相交,向上回退。

4)当回退到根结点时,搜索结束最后的“当前最近点”即为x的最近邻点。

如果实例点是随机分布的,kd树搜索的平均计算复杂度是O(log N),这里N是训练实例数。kd树更适用于训练实例数远大于空间维数时的k近邻搜索。当空间维数接近训练实例数时,它的效率会迅速下降,几乎接近线性扫描。

4、K近邻算法Python实践——识别手写数字

'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)
            
Output:     the most popular class label

@author: pbharrin
'''
import numpy as np
import operator
from os import listdir

# K近邻算法
# inX:用于分类的输入向量
# dataSet:输入的训练样本集
# labels: 标签向量
# k:选择最邻近的数目
# 按照欧式距离公式计算,并将结果由小到大排序
def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

#准备数据:将图像转化为测试向量
def img2vector(filename):
    returnVect = np.zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

#测试算法,使用K-邻近算法识别手写数字
def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = np.zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print ("\nthe total number of errors is: %d" % errorCount)
    print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))
    
if __name__ == "__main__":
    #测试算法,使用K-邻近算法识别手写数字
    handwritingClassTest()
View Code

参考文献 

[1] 李航. 统计学习方法[M]. 北京:清华大学出版社,2012.

[2] Peter. 机器学习实战[M]. 北京:人民邮电出版社,2013.

[3] 赵志勇. Python机器学习算法[M]. 北京:电子工业出版社,2017.

[4] 周志华. 机器学习[M]. 北京:清华大学出版社,2016.

附录

数据集下载

链接:https://pan.baidu.com/s/1aO0y-LA37hkkzYI4BCQjmA
提取码:yoix

推荐阅读