首页 > 解决方案 > KNNImputer 与 scikit-learn

问题描述

我对具有 14 个不同传感器的引擎进行了测量,测量间隔为 6 秒,它显示为具有 14 列和大约 5000 行的 numpy 数组。每行代表一个测量点,每列代表一个特征。

1% 的数据集是 NaN,我想估算它们以将它们与 SVM 一起使用。

由于数据集是动态引擎的时间序列,因此仅查看缺失值的 2 个最近的数据点才有意义:一个数据点之前和一个数据点之后。它应该计算 2 个最近数据点的平均值。

我认为这应该可以通过 scikit-learn 的 KNNImputer 实现,但是当我有一整行 NaN 时我并不满意。看这个例子:

15.30      80.13   20000.00   15000.00     229.00     698.00     590.00      24.00      82.53      1522.00     410.00     406.00     407.00      50.01
nan        nan        nan        nan        nan        nan        nan        nan        nan        nan        nan        nan        nan        nan
15.30      82.90   20000.00   15000.00     225.00     698.00     628.00      24.00      85.36    1523.00     410.00     407.00     408.00      50.02

KNNImputer 的输出如下所示:

15.30      80.13   20000.00   15000.00     229.00     698.00     590.00      24.00      82.53    1522.00     410.00     406.00     407.00      50.01
19.90      91.88   19997.09   19945.58     327.14     829.40     651.23      25.97      94.80    1529.65     410.20     406.69     407.72      49.99
15.30      82.90   20000.00   15000.00     225.00     698.00     628.00      24.00      85.36    1523.00     410.00     407.00     408.00      50.02

看第一列,除了 NaN:(15.30 + 15.30)/2=15.30

相反,我得到了 19.90。

我的代码:

from sklearn.impute import KNNImputer

imp = KNNImputer(n_neighbors=2)  
X_afterImputer = imp.fit_transform(X_beforeImputer)

有什么想法吗?

标签: pythonscikit-learn

解决方案


我为你做了一个函数。这是一个可重现的示例,因此您可以了解它是如何工作的:

import numpy as np

arr = np.random.randint(0, 10, (10, 4)).astype(float)

arr[2, 0] = np.nan
arr[4, 3] = np.nan
arr[0, 2] = np.nan

print(arr)
[[ 5.  7. nan  4.]
 [ 2.  6.  4.  9.]
 [nan  2.  5.  5.]
 [ 7.  0.  3.  8.]
 [ 6.  4.  3. nan]
 [ 8.  1.  2.  0.]
 [ 0.  0.  1.  1.]
 [ 1.  2.  6.  6.]
 [ 8.  1.  9.  7.]
 [ 3.  5.  8.  8.]]
for x in np.argwhere(np.isnan(arr)):
    sample = arr[np.maximum(x[0] - 1, 0):np.minimum(x[0] + 2, 20), x[1]]
    arr[x[0], x[1]] = np.mean(sample[np.logical_not(np.isnan(sample))])
print(arr)
[[5.  7.  4.  4. ] # 3rd value here is mean(4)
 [2.  6.  4.  9. ]
 [4.5 2.  5.  5. ] # first value here is mean(2, 7)
 [7.  0.  3.  8. ]
 [6.  4.  3.  4. ] # 4th value here is mean(8, 0)
 [8.  1.  2.  0. ]
 [0.  0.  1.  1. ]
 [1.  2.  6.  6. ]
 [8.  1.  9.  7. ]
 [3.  5.  8.  8. ]]

逻辑如下:

for every location (x, y) where value is missing:
    take previous and next value (if possible)
    assign the mean of these two values to the location (x, y)

推荐阅读