python - 如何找到一个点的最近邻居和另一个不是最近邻居的点?
问题描述
我的任务是找到一些点的最近邻居,并删除其他不是最近邻居的点。这个任务就像下采样问题。
到目前为止的代码:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import spatial
data = pd.read_csv('data.csv')
majority = data.loc[data['class']==0]
minority = data.loc[data['class']==1]
majority_points=majority.drop('class', axis=1)
minority_points=minority.drop('class', axis=1)
all_data = pd.concat([majority,minority])
data_points = all_data.drop('class', axis=1)
# print(data_points)
majority_points=np.array(majority_points)
print (majority_points)
minority_points =np.array(minority_points)
print (minority_points)
#result
[[1 1]
[1 2]
[1 3]
[1 4]
[1 5]
[2 1]
[2 2]
[2 4]
[2 5]
[3 1]
[3 2]
[3 5]
[4 1]
[4 4]
[4 5]
[5 1]
[5 2]
[5 3]
[5 4]
[5 5]] (20, 2)
[[2 3]
[3 3]
[3 4]
[4 2]
[4 3]]
`
#to find nearest neighbor
from scipy.spatial import distance
Y = distance.cdist(minority_points, majority_points, 'euclidean')
K = np.argsort(Y)
print (Y)
print ("Ordered data: \n", K)
Y.sort()
print ("After short: \n", Y)
#result
[[2.23606798 1.41421356 1. 1.41421356 2.23606798 2.
1. 1. 2. 2.23606798 1.41421356 2.23606798
2.82842712 2.23606798 2.82842712 3.60555128 3.16227766 3.
3.16227766 3.60555128]
[2.82842712 2.23606798 2. 2.23606798 2.82842712 2.23606798
1.41421356 1.41421356 2.23606798 2. 1. 2.
2.23606798 1.41421356 2.23606798 2.82842712 2.23606798 2.
2.23606798 2.82842712]
[3.60555128 2.82842712 2.23606798 2. 2.23606798 3.16227766
2.23606798 1. 1.41421356 3. 2. 1.
3.16227766 1. 1.41421356 3.60555128 2.82842712 2.23606798
2. 2.23606798]
[3.16227766 3. 3.16227766 3.60555128 4.24264069 2.23606798
2. 2.82842712 3.60555128 1.41421356 1. 3.16227766
1. 2. 3. 1.41421356 1. 1.41421356
2.23606798 3.16227766]
[3.60555128 3.16227766 3. 3.16227766 3.60555128 2.82842712
2.23606798 2.23606798 2.82842712 2.23606798 1.41421356 2.23606798
2. 1. 2. 2.23606798 1.41421356 1.
1.41421356 2.23606798]]
Ordered data:
[[ 2 6 7 1 3 10 5 8 0 13 11 9 4 12 14 17 16 18 15 19]
[10 6 7 13 9 2 17 11 1 3 5 8 18 12 14 16 0 4 15 19]
[ 7 11 13 8 14 3 18 10 19 2 4 17 6 1 16 9 12 5 15 0]
[16 10 12 9 17 15 6 13 5 18 7 1 14 0 2 11 19 8 3 4]
[17 13 16 10 18 14 12 9 15 11 19 7 6 5 8 2 3 1 4 0]]
After short:
[[1. 1. 1. 1.41421356 1.41421356 1.41421356
2. 2. 2.23606798 2.23606798 2.23606798 2.23606798
2.23606798 2.82842712 2.82842712 3. 3.16227766 3.16227766
3.60555128 3.60555128]
[1. 1.41421356 1.41421356 1.41421356 2. 2.
2. 2. 2.23606798 2.23606798 2.23606798 2.23606798
2.23606798 2.23606798 2.23606798 2.23606798 2.82842712 2.82842712
2.82842712 2.82842712]
[1. 1. 1. 1.41421356 1.41421356 2.
2. 2. 2.23606798 2.23606798 2.23606798 2.23606798
2.23606798 2.82842712 2.82842712 3. 3.16227766 3.16227766
3.60555128 3.60555128]
[1. 1. 1. 1.41421356 1.41421356 1.41421356
2. 2. 2.23606798 2.23606798 2.82842712 3.
3. 3.16227766 3.16227766 3.16227766 3.16227766 3.60555128
3.60555128 4.24264069]
[1. 1. 1.41421356 1.41421356 1.41421356 2.
2. 2.23606798 2.23606798 2.23606798 2.23606798 2.23606798
2.23606798 2.82842712 2.82842712 3. 3.16227766 3.16227766
3.60555128 3.60555128]]
我想将少数点中每个点的 3 个最近邻点设为多数点,并保留其数组的值,其余的被删除。
这是插图:
红点是少数示例,蓝点是多数示例。因此,每个少数类计算它的,例如,与多数最近的 3 个邻居。然后该算法删除了一些不是最近邻居的点。
解决方案
看来您已经走得很远了,您需要一些帮助才能继续前进。您对距离进行排序的方式存在一个问题。您首先创建一个距离数组,然后对它们进行排序,但这样做会丢失每个距离的上下文信息。你有所有的距离,你对它们进行了排序,但你不知道它们适用于哪些点。这一步对我来说似乎没有必要,您可以停止使用K
哪个是现有项目的索引(不会丢失上下文信息),同时获得有关其排序位置的信息。如果您看一下K
,您可能会注意到它是一个 5x20 矩阵,并且给定您的majority_points
和的形状minoirty_points
,(分别为 20x2 和 5x2)这表明K
n
minority_point
是一个二维矩阵,每一行是给定所有m
majority_point
s之间的排序距离的索引。
让我们看一个例子。第一行K
是
[ 2 6 7 1 3 10 5 8 0 13 11 9 4 12 14 17 16 18 15 19]
这意味着首先minority_point
[2, 3]
,最接近的多数点在索引 2、6、7、1、3 等处按顺序排列,结果按[1, 3]
[2, 2]
[2, 4]
[1, 2]
顺序排列。如果你看一下,前 3 个是最短的距离,与minority_point
有问题的距离只有 1 个单位。(第一个是垂直的,第二个是水平的。) (请注意,排序算法将最接近的项目和第一个项目都放置在输入数组中。因此索引 2 位于 6 或 7 之前,即使它们都是相同的距离。)
因此,您需要做的就是创建一个新数组,复制 K 中每一行的前 3 个索引处的项目。这与majority_points
为每个minority_points
.
在 numpy 约定中,这意味着查看majority_points
数组(而不是复制整个内容。)以下行应该可以工作(我测试过):
majority_points[K[:,0:3]]
这意味着您从所有行(第一个维度,逗号之前,都只是,)和第二个维度的前 3 个元素中获取 () 中的元素K
。这些是您想从中获取的索引。K[]
:
majority_points
我的输出看起来像:
[[[1 3]
[2 2]
[2 4]]
[[3 2]
[2 2]
[2 4]]
[[2 4]
[3 5]
[4 4]]
[[5 2]
[3 2]
[4 1]]
[[5 3]
[4 4]
[5 2]]]
推荐阅读
- javascript - 如何将数据从嵌套对象合并到一个数组(使用地图)
- django - 如何在 django 中更新实例后将成功 url 设置为上一页
- android - 尝试在 android studio 上运行应用程序时收到“访问被拒绝”消息
- angular - [style.width.px] 中的角度条件
- php - Yii 1.1 允许的内存大小为 536870912 字节耗尽(尝试分配 72 字节)
- node.js - aws 用 node.js 中的 aws lambda 理解
- ios - 如何处理 AVPlayer 中的超时请求失败?
- microsoft-graph-api - 从 Microsoft Teams 中的 MS Graph API 调用 Read Reports API,但使用此 Api
- typo3-extensions - Typo3 10.4.1 扩展生成器:没有为新扩展创建表
- jmeter - 使用jmeter在html响应中提取值的问题