首页 > 解决方案 > 如何找到一个点的最近邻居和另一个不是最近邻居的点?

问题描述

我的任务是找到一些点的最近邻居,并删除其他不是最近邻居的点。这个任务就像下采样问题。

到目前为止的代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import spatial

data = pd.read_csv('data.csv')
majority = data.loc[data['class']==0]
minority = data.loc[data['class']==1]

majority_points=majority.drop('class', axis=1)
minority_points=minority.drop('class', axis=1)

all_data = pd.concat([majority,minority])

data_points = all_data.drop('class', axis=1)
# print(data_points)

majority_points=np.array(majority_points)
print (majority_points)
minority_points =np.array(minority_points)
print (minority_points)
    #result
    [[1 1]
 [1 2]
 [1 3]
 [1 4]
 [1 5]
 [2 1]
 [2 2]
 [2 4]
 [2 5]
 [3 1]
 [3 2]
 [3 5]
 [4 1]
 [4 4]
 [4 5]
 [5 1]
 [5 2]
 [5 3]
 [5 4]
 [5 5]] (20, 2)
[[2 3]
 [3 3]
 [3 4]
 [4 2]
 [4 3]]

`

#to find nearest neighbor
from scipy.spatial import distance
Y = distance.cdist(minority_points, majority_points, 'euclidean')

K = np.argsort(Y)
print (Y)
print ("Ordered data: \n", K)
Y.sort()
print ("After short: \n", Y)
#result
[[2.23606798 1.41421356 1.         1.41421356 2.23606798 2.
  1.         1.         2.         2.23606798 1.41421356 2.23606798
  2.82842712 2.23606798 2.82842712 3.60555128 3.16227766 3.
  3.16227766 3.60555128]
 [2.82842712 2.23606798 2.         2.23606798 2.82842712 2.23606798
  1.41421356 1.41421356 2.23606798 2.         1.         2.
  2.23606798 1.41421356 2.23606798 2.82842712 2.23606798 2.
  2.23606798 2.82842712]
 [3.60555128 2.82842712 2.23606798 2.         2.23606798 3.16227766
  2.23606798 1.         1.41421356 3.         2.         1.
  3.16227766 1.         1.41421356 3.60555128 2.82842712 2.23606798
  2.         2.23606798]
 [3.16227766 3.         3.16227766 3.60555128 4.24264069 2.23606798
  2.         2.82842712 3.60555128 1.41421356 1.         3.16227766
  1.         2.         3.         1.41421356 1.         1.41421356
  2.23606798 3.16227766]
 [3.60555128 3.16227766 3.         3.16227766 3.60555128 2.82842712
  2.23606798 2.23606798 2.82842712 2.23606798 1.41421356 2.23606798
  2.         1.         2.         2.23606798 1.41421356 1.
  1.41421356 2.23606798]]
Ordered data: 
 [[ 2  6  7  1  3 10  5  8  0 13 11  9  4 12 14 17 16 18 15 19]
 [10  6  7 13  9  2 17 11  1  3  5  8 18 12 14 16  0  4 15 19]
 [ 7 11 13  8 14  3 18 10 19  2  4 17  6  1 16  9 12  5 15  0]
 [16 10 12  9 17 15  6 13  5 18  7  1 14  0  2 11 19  8  3  4]
 [17 13 16 10 18 14 12  9 15 11 19  7  6  5  8  2  3  1  4  0]]
After short: 
 [[1.         1.         1.         1.41421356 1.41421356 1.41421356
  2.         2.         2.23606798 2.23606798 2.23606798 2.23606798
  2.23606798 2.82842712 2.82842712 3.         3.16227766 3.16227766
  3.60555128 3.60555128]
 [1.         1.41421356 1.41421356 1.41421356 2.         2.
  2.         2.         2.23606798 2.23606798 2.23606798 2.23606798
  2.23606798 2.23606798 2.23606798 2.23606798 2.82842712 2.82842712
  2.82842712 2.82842712]
 [1.         1.         1.         1.41421356 1.41421356 2.
  2.         2.         2.23606798 2.23606798 2.23606798 2.23606798
  2.23606798 2.82842712 2.82842712 3.         3.16227766 3.16227766
  3.60555128 3.60555128]
 [1.         1.         1.         1.41421356 1.41421356 1.41421356
  2.         2.         2.23606798 2.23606798 2.82842712 3.
  3.         3.16227766 3.16227766 3.16227766 3.16227766 3.60555128
  3.60555128 4.24264069]
 [1.         1.         1.41421356 1.41421356 1.41421356 2.
  2.         2.23606798 2.23606798 2.23606798 2.23606798 2.23606798
  2.23606798 2.82842712 2.82842712 3.         3.16227766 3.16227766
  3.60555128 3.60555128]]

我想将少数点中每个点的 3 个最近邻点设为多数点,并保留其数组的值,其余的被删除。

这是插图:

  1. 重采样/原始数据集之前
  2. 重采样后

红点是少数示例,蓝点是多数示例。因此,每个少数类计算它的,例如,与多数最近的 3 个邻居。然后该算法删除了一些不是最近邻居的点。

标签: pythonpandasnumpy

解决方案


看来您已经走得很远了,您需要一些帮助才能继续前进。您对距离进行排序的方式存在一个问题。您首先创建一个距离数组,然后对它们进行排序,但这样做会丢失每个距离的上下文信息。你有所有的距离,你对它们进行了排序,但你不知道它们适用于哪些点。这一步对我来说似乎没有必要,您可以停止使用K哪个是现有项目的索引(不会丢失上下文信息),同时获得有关其排序位置的信息。如果您看一下K,您可能会注意到它是一个 5x20 矩阵,并且给定您的majority_points和的形状minoirty_points,(分别为 20x2 和 5x2)这表明Kn minority_point是一个二维矩阵,每一行是给定所有m majority_points之间的排序距离的索引。

让我们看一个例子。第一行K

[ 2  6  7  1  3 10  5  8  0 13 11  9  4 12 14 17 16 18 15 19]

这意味着首先minority_point [2, 3],最接近的多数点在索引 2、6、7、1、3 等处按顺序排列,结果按[1, 3] [2, 2] [2, 4] [1, 2]顺序排列。如果你看一下,前 3 个是最短的距离,与minority_point有问题的距离只有 1 个单位。(第一个是垂直的,第二个是水平的。) (请注意,排序算法将最接近的项目和第一个项目都放置在输入数组中。因此索引 2 位于 6 或 7 之前,即使它们都是相同的距离。)

因此,您需要做的就是创建一个新数组,复制 K 中每一行的前 3 个索引处的项目。这与majority_points为每个minority_points.

在 numpy 约定中,这意味着查看majority_points数组(而不是复制整个内容。)以下行应该可以工作(我测试过):

majority_points[K[:,0:3]]

这意味着您从所有行(第一个维度,逗号之前,都只是,)和第二个维度的前 3 个元素中获取 () 中的元素K。这些是您想从中获取的索引。K[]:majority_points

我的输出看起来像:

[[[1 3]
  [2 2]
  [2 4]]

 [[3 2]
  [2 2]
  [2 4]]

 [[2 4]
  [3 5]
  [4 4]]

 [[5 2]
  [3 2]
  [4 1]]

 [[5 3]
  [4 4]
  [5 2]]]

推荐阅读