首页 > 技术文章 > kNN(k-Nearest Neighbors)JavaScript实现

xkxf 2022-02-24 15:59 原文



// Key terminology:
// training set has training examples (features + target variable)
// In the classification problem the target variables are called classes
// test set
// knowledge representation
/**
 * Builds a random labelled data set for experimenting with kNN.
 *
 * Every example has the shape `{ features, label }`: `features` holds
 * `dimension_of_feature` random numbers in [0, 10) and `label` is a random
 * integer class in {0, 1, 2}.
 *
 * @param {number} dimension_of_feature - Length of every feature vector.
 * @param {number} num_of_examples - Number of examples to generate.
 * @returns {{features: number[], label: number}[]} The generated examples
 *   (empty when `num_of_examples` is 0 or less).
 */
function makeTrainingSet(dimension_of_feature, num_of_examples) {
    const result = [];
    // `<` instead of `!=` so a negative or fractional count cannot loop forever.
    for (let i = 0; i < num_of_examples; ++i) {
        const example = {
            features: [],
            label: Math.floor(Math.random() * 3),
        };
        for (let j = 0; j < dimension_of_feature; ++j) {
            example.features[j] = Math.random() * 10;
        }
        result.push(example);
    }
    return result;
}

// Simulated training set: 20 examples, each with a 3-dimensional feature vector.
const trainingSet = makeTrainingSet(3, 20);

// One extra random example supplies the feature vector used as the test query.
const [{ features }] = makeTrainingSet(3, 1);

/**
 * Computes the Euclidean distance between two feature vectors.
 *
 * The number of dimensions is taken from `f1`; `f2` is expected to have the
 * same length.
 *
 * @param {number[]} f1 - First feature vector.
 * @param {number[]} f2 - Second feature vector.
 * @returns {number} Square root of the summed squared differences
 *   (0 for empty vectors).
 */
function getDistance(f1, f2) {
    const DIMENSION = f1.length;
    let sumOfSquares = 0;
    for (let i = 0; i < DIMENSION; ++i) {
        sumOfSquares += Math.pow(f1[i] - f2[i], 2);
    }
    return Math.sqrt(sumOfSquares);
}

/**
 * Finds the k training examples closest to a query feature vector.
 *
 * @param {number[]} features - Feature vector of the unlabelled query point.
 * @param {number} k - Number of neighbours to return.
 * @param {{features: number[], label: number}[]} trainingSet - Labelled examples.
 * @returns {{distance: number, label: number}[]} At most `k` entries, sorted by
 *   ascending distance to `features`; each carries the neighbour's distance
 *   and label.
 */
function getKNearestNeighbors(features, k, trainingSet) {
    const distanceSet = trainingSet.map(training_example => ({
        distance: getDistance(features, training_example.features),
        label: training_example.label,
    }));
    // Sorting in place is safe: distanceSet is a fresh array owned by this call.
    distanceSet.sort((a, b) => a.distance - b.distance);
    return distanceSet.slice(0, k);
}

// Print the 5 nearest labelled neighbours of the query vector.
const nearestNeighbors = getKNearestNeighbors(features, 5, trainingSet);
console.log(nearestNeighbors);






<!DOCTYPE html>
    <meta charset="utf-8" />
    <!-- 引入刚刚下载的 ECharts 文件 -->
    <script src="echarts.js"></script>
    <style>
        main {
            display: flex;
            flex-wrap: wrap;
            justify-content: center;
        }
    </style>
    <main></main>
    <script type="text/javascript">
      /**
       * Returns a random number in the half-open interval [0, val).
       *
       * @param {number} val - Exclusive upper bound of the interval.
       * @returns {number} `Math.random()` scaled by `val`.
       */
      function getRd(val) {
        return Math.random() * val;
      }

      /**
       * Generates labelled 2-D points whose label is determined by the 5x5
       * square they are sampled from.
       *
       * Label 0 covers x, y in [0, 5); label 1 shifts x by 5; label 2 shifts
       * both x and y by 5; label 3 shifts y by 5.
       *
       * @param {number} num_of_examples - Number of examples to generate.
       * @returns {{features: number[], label: number}[]} The generated examples.
       */
      function make2DTrainingSet(num_of_examples) {
        // Lower-left corner of the square each label (array index) is drawn from.
        const squareOrigins = [[0, 0], [5, 0], [5, 5], [0, 5]];
        const result = [];
        for (let i = 0; i < num_of_examples; ++i) {
          const label = Math.floor(getRd(4));
          const origin = squareOrigins[label];
          result.push({
            features: [origin[0] + getRd(5), origin[1] + getRd(5)],
            label: label,
          });
        }
        return result;
      }

      // Labelled 2-D points to plot.
      const trainingSet = make2DTrainingSet(100);

      const WIDTH = 600;
      const HEIGHT = 400;
      // <main> is the container that holds the chart's div.
      const container = document.querySelector('main');
      const myDiv = createDiv(WIDTH, HEIGHT);
      // Attach the div before initialising the chart so it has a measurable size.
      container.appendChild(myDiv);
      const chart = echarts.init(myDiv);
      chart.setOption({
        xAxis: {},
        yAxis: {},
        // One scatter series per class label (0..3).
        series: [0, 1, 2, 3].map(label => ({
          type: 'scatter',
          data: trainingSet.filter(x => x.label === label).map(x => x.features),
        })),
      });

      /**
       * Creates a detached <div> with a fixed pixel size.
       *
       * @param {number} width - Width in pixels.
       * @param {number} height - Height in pixels.
       * @param {string} [display='block'] - Currently unused; kept in the
       *   signature so existing calls that pass it keep working.
       * @returns {HTMLDivElement} The new element (not yet attached to the document).
       */
      function createDiv(width, height, display = 'block') {
        const result = document.createElement('div');
        result.style.width = `${width}px`;
        result.style.height = `${height}px`;
        return result;
      }

      /**
       * Generates a list of random sample records.
       *
       * @param {number} dataSize - Number of records to generate.
       * @returns {{time: string, data: number}[]} Records whose `time` is
       *   "t0", "t1", ... and whose `data` is a random number in [0, 1).
       */
      function makeDate(dataSize) {
        const result = [];
        for (let i = 0; i < dataSize; ++i) {
          result.push({
            time: `t${i}`,
            data: Math.random(),
          });
        }
        return result;
      }


// Tunable parameters
// NOTE(review): TRAINING_SET_SIZE is read by the evaluation code below but was
// not declared in this snippet; 200 is an assumed value — confirm.
const TRAINING_SET_SIZE = 200;
const TEST_SET_SIZE = 100;
const K = 15;

/**
 * Returns a random number in the half-open interval [0, val).
 *
 * @param {number} val - Exclusive upper bound of the interval.
 * @returns {number} `Math.random()` scaled by `val`.
 */
function getRd(val) {
    return Math.random() * val;
}

/**
 * Generates labelled 2-D points whose label is determined by the 5x5 square
 * they are sampled from.
 *
 * Label 0 covers x, y in [0, 5); label 1 shifts x by 5; label 2 shifts both
 * x and y by 5; label 3 shifts y by 5.
 *
 * @param {number} num_of_examples - Number of examples to generate.
 * @returns {{features: number[], label: number}[]} The generated examples.
 */
function make2DTrainingSet(num_of_examples) {
    // Lower-left corner of the square each label (array index) is drawn from.
    const squareOrigins = [[0, 0], [5, 0], [5, 5], [0, 5]];
    const result = [];
    for (let i = 0; i < num_of_examples; ++i) {
        const label = Math.floor(getRd(4));
        const origin = squareOrigins[label];
        result.push({
            features: [origin[0] + getRd(5), origin[1] + getRd(5)],
            label: label,
        });
    }
    return result;
}

/**
 * Computes the Euclidean distance between two feature vectors.
 *
 * The number of dimensions is taken from `f1`; `f2` is expected to have the
 * same length.
 *
 * @param {number[]} f1 - First feature vector.
 * @param {number[]} f2 - Second feature vector.
 * @returns {number} Square root of the summed squared differences
 *   (0 for empty vectors).
 */
function getDistance(f1, f2) {
    const DIMENSION = f1.length;
    let sumOfSquares = 0;
    for (let i = 0; i < DIMENSION; ++i) {
        sumOfSquares += Math.pow(f1[i] - f2[i], 2);
    }
    return Math.sqrt(sumOfSquares);
}

/**
 * Finds the k training examples closest to a query feature vector.
 *
 * @param {number[]} features - Feature vector of the unlabelled query point.
 * @param {number} k - Number of neighbours to return.
 * @param {{features: number[], label: number}[]} trainingSet - Labelled examples.
 * @returns {{distance: number, label: number}[]} At most `k` entries, sorted by
 *   ascending distance to `features`; each carries the neighbour's distance
 *   and label.
 */
function getKNearestNeighbors(features, k, trainingSet) {
    const distanceSet = trainingSet.map(training_example => ({
        distance: getDistance(features, training_example.features),
        label: training_example.label,
    }));
    // Sorting in place is safe: distanceSet is a fresh array owned by this call.
    distanceSet.sort((a, b) => a.distance - b.distance);
    return distanceSet.slice(0, k);
}

/**
 * Predicts the label of a feature vector by majority vote among its k
 * nearest neighbours.
 *
 * Ties go to the label that was seen first while walking the neighbours in
 * the order returned by getKNearestNeighbors, because a later label only
 * replaces the current winner when it has strictly more votes.
 *
 * @param {number[]} features - Feature vector to classify.
 * @param {number} k - Number of neighbours that vote.
 * @param {{features: number[], label: number}[]} trainingSet - Labelled examples.
 * @returns {number|undefined} The winning label, or `undefined` when there
 *   are no neighbours.
 */
function getTargetVariable(features, k, trainingSet) {
    const neighbors = getKNearestNeighbors(features, k, trainingSet);
    // label -> number of neighbours carrying that label
    const votes = new Map();
    neighbors.forEach(x => {
        votes.set(x.label, (votes.get(x.label) || 0) + 1);
    });
    let result;
    let max = 0;
    votes.forEach((val, key) => {
        if (val > max) {
            result = key;
            max = val;
        }
    });
    return result;
}

const trainingSet = make2DTrainingSet(TRAINING_SET_SIZE);
const testSet = make2DTrainingSet(TEST_SET_SIZE);

// Classify every test example and count the predictions that match the true label.
let correct = 0;
for (const testExample of testSet) {
    const pv = getTargetVariable(testExample.features, K, trainingSet);
    if (pv === testExample.label) {
        // Without this increment the reported accuracy would always be 0.
        ++correct;
    } else {
        // Log every misclassified example with its true and predicted label.
        console.log('========W A=========');
        console.log(testExample.label + ' 预测成 ' + pv);
        console.log('features: ', testExample.features);
    }
}

// Summary of the run.
console.log('training set size:', TRAINING_SET_SIZE);
console.log('k:', K);
console.log('correct/total:', correct, TEST_SET_SIZE);
console.log('accurate: ', correct / testSet.length);



<!DOCTYPE html>
    <meta charset="utf-8" />
    <!-- 引入刚刚下载的 ECharts 文件 -->
    <script src="echarts.js"></script>
    <style>
        main {
            display: flex;
            flex-wrap: wrap;
            justify-content: center;
        }
    </style>
    <main></main>
      // Tunable parameters
      const TRAINING_SET_SIZE = 200 // number of generated training examples
      const TEST_SET_SIZE = 100 // number of generated test examples
      const K = 20 // number of nearest neighbours that vote on each prediction

      /**
       * Returns a random number in the half-open interval [0, val).
       *
       * @param {number} val - Exclusive upper bound of the interval.
       * @returns {number} `Math.random()` scaled by `val`.
       */
      function getRd(val) {
          return Math.random() * val;
      }

      /**
       * Generates labelled 2-D points whose label is determined by the 5x5
       * square they are sampled from.
       *
       * Label 0 covers x, y in [0, 5); label 1 shifts x by 5; label 2 shifts
       * both x and y by 5; label 3 shifts y by 5.
       *
       * @param {number} num_of_examples - Number of examples to generate.
       * @returns {{features: number[], label: number}[]} The generated examples.
       */
      function make2DTrainingSet(num_of_examples) {
          // Lower-left corner of the square each label (array index) is drawn from.
          const squareOrigins = [[0, 0], [5, 0], [5, 5], [0, 5]];
          const result = [];
          for (let i = 0; i < num_of_examples; ++i) {
              const label = Math.floor(getRd(4));
              const origin = squareOrigins[label];
              result.push({
                  features: [origin[0] + getRd(5), origin[1] + getRd(5)],
                  label: label,
              });
          }
          return result;
      }

      /**
       * Computes the Euclidean distance between two feature vectors.
       *
       * The number of dimensions is taken from `f1`; `f2` is expected to have
       * the same length.
       *
       * @param {number[]} f1 - First feature vector.
       * @param {number[]} f2 - Second feature vector.
       * @returns {number} Square root of the summed squared differences
       *   (0 for empty vectors).
       */
      function getDistance(f1, f2) {
          const DIMENSION = f1.length;
          let sumOfSquares = 0;
          for (let i = 0; i < DIMENSION; ++i) {
              sumOfSquares += Math.pow(f1[i] - f2[i], 2);
          }
          return Math.sqrt(sumOfSquares);
      }

      /**
       * Finds the k training examples closest to a query feature vector.
       *
       * @param {number[]} features - Feature vector of the unlabelled query point.
       * @param {number} k - Number of neighbours to return.
       * @param {{features: number[], label: number}[]} trainingSet - Labelled examples.
       * @returns {{distance: number, label: number}[]} At most `k` entries,
       *   sorted by ascending distance to `features`; each carries the
       *   neighbour's distance and label.
       */
      function getKNearestNeighbors(features, k, trainingSet) {
          const distanceSet = trainingSet.map(training_example => ({
              distance: getDistance(features, training_example.features),
              label: training_example.label,
          }));
          // Sorting in place is safe: distanceSet is a fresh array owned by this call.
          distanceSet.sort((a, b) => a.distance - b.distance);
          return distanceSet.slice(0, k);
      }

      /**
       * Predicts the label of a feature vector by majority vote among its k
       * nearest neighbours.
       *
       * Ties go to the label that was seen first while walking the neighbours
       * in the order returned by getKNearestNeighbors, because a later label
       * only replaces the current winner when it has strictly more votes.
       *
       * @param {number[]} features - Feature vector to classify.
       * @param {number} k - Number of neighbours that vote.
       * @param {{features: number[], label: number}[]} trainingSet - Labelled examples.
       * @returns {number|undefined} The winning label, or `undefined` when
       *   there are no neighbours.
       */
      function getTargetVariable(features, k, trainingSet) {
          const neighbors = getKNearestNeighbors(features, k, trainingSet);
          // label -> number of neighbours carrying that label
          const votes = new Map();
          neighbors.forEach(x => {
              votes.set(x.label, (votes.get(x.label) || 0) + 1);
          });
          let result;
          let max = 0;
          votes.forEach((val, key) => {
              if (val > max) {
                  result = key;
                  max = val;
              }
          });
          return result;
      }

      const trainingSet = make2DTrainingSet(TRAINING_SET_SIZE);
      const testSet = make2DTrainingSet(TEST_SET_SIZE);

      // Classify every test example and count the predictions that match the true label.
      let correct = 0;
      // Feature vectors of the misclassified test examples; read later as chart data.
      const fail = [];
      for (const testExample of testSet) {
          const pv = getTargetVariable(testExample.features, K, trainingSet);
          if (pv === testExample.label) {
              // Without this increment the reported accuracy would always be 0.
              ++correct;
          } else {
              // Record the miss so it can be plotted, then log its true and predicted label.
              fail.push(testExample.features);
              console.log('========W A=========');
              console.log(testExample.label + ' 预测成 ' + pv);
              console.log('features: ', testExample.features);
          }
      }

      // Summary of the run.
      console.log('training set size:', TRAINING_SET_SIZE);
      console.log('k:', K);
      console.log('correct/total:', correct, TEST_SET_SIZE);
      console.log('accurate: ', correct / testSet.length);


    <!-- 展示数据 -->
    <script type="text/javascript">
      const WIDTH = 600;
      const HEIGHT = 400;
      // <main> is the container that holds the chart's div.
      const container = document.querySelector('main');
      const myDiv = createDiv(WIDTH, HEIGHT);
      // Attach the div before initialising the chart so it has a measurable size.
      container.appendChild(myDiv);
      const chart = echarts.init(myDiv);

      chart.setOption({
        xAxis: {},
        yAxis: {},
        series: [
          // One scatter series per class label (0..3) of the training set.
          ...[0, 1, 2, 3].map(label => ({
            type: 'scatter',
            data: trainingSet.filter(x => x.label === label).map(x => x.features),
          })),
          // Extra series holding the contents of `fail`.
          {
            type: 'scatter',
            data: fail,
          },
        ],
      });

      /**
       * Creates a detached <div> with a fixed pixel size.
       *
       * @param {number} width - Width in pixels.
       * @param {number} height - Height in pixels.
       * @param {string} [display='block'] - Currently unused; kept in the
       *   signature so existing calls that pass it keep working.
       * @returns {HTMLDivElement} The new element (not yet attached to the document).
       */
      function createDiv(width, height, display = 'block') {
        const result = document.createElement('div');
        result.style.width = `${width}px`;
        result.style.height = `${height}px`;
        return result;
      }

      /**
       * Generates a list of random sample records.
       *
       * @param {number} dataSize - Number of records to generate.
       * @returns {{time: string, data: number}[]} Records whose `time` is
       *   "t0", "t1", ... and whose `data` is a random number in [0, 1).
       */
      function makeDate(dataSize) {
        const result = [];
        for (let i = 0; i < dataSize; ++i) {
          result.push({
            time: `t${i}`,
            data: Math.random(),
          });
        }
        return result;
      }





