首页 > 解决方案 > 使用交叉验证的 kNN 算法的参数

问题描述

我正在使用机器学习算法 kNN,而不是将数据集划分为 66.6% 的训练和 33.4% 的测试,我需要使用具有以下参数的交叉验证:K=3, 1/euclidean

K=3 没有什么玄机,我只是在代码中添加:

Classifier = KNeighborsClassifier(n_neighbors=3, p=2, metric='euclidean') 

它解决了。我无法理解的是1/euclidean,以及如何将其应用于代码?

import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
from sklearn import metrics

def openfile():
   df = pd.read_csv('Testfile - kNN.csv')

   return df


def main():

   start_time = time.time()
   dataset = openfile()

   X = dataset.drop(columns=['Label'])
   y = dataset['Label'].values

   X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

   Classifier = KNeighborsClassifier(n_neighbors=3, p=2, metric='euclidean')
   Classifier.fit(X_train, y_train)

   y_pred_class = Classifier.predict(X_test)

   score = cross_val_score(Classifier, X, y, cv=10)

   y_pred_prob = Classifier.predict_proba(X_test)[:, 1]

   print("accuracy_score:", metrics.accuracy_score(y_test, y_pred_class),'\n')

   print("confusion matrix")
   print(metrics.confusion_matrix(y_test, y_pred_class),'\n')

   print("Background precision score:", metrics.precision_score(y_test, y_pred_class, labels=['background'], average='micro')*100,"%")
   print("Botnet precision score:", metrics.precision_score(y_test, y_pred_class, labels=['bot'], average='micro')*100,"%")
   print("Normal precision score:", metrics.precision_score(y_test, y_pred_class, labels=['normal'], average='micro')*100,"%",'\n')

   print(metrics.classification_report(y_test, y_pred_class, digits=2),'\n')
   print(score,'\n')
   print(score.mean(),'\n')


   print("--- %s seconds ---" % (time.time() - start_time))

标签: pythonmachine-learningscikit-learncross-validationknn

解决方案


您可以创建自己的函数并将其作为可调用参数传递给metric参数。

创建您的函数,如下所示:

from scipy.spatial import distance
def inverse_euc(a,b):
    return 1/distance.euclidean(a, b)

现在callable在你的KNN函数中使用它:

Classifier = KNeighborsClassifier(algorithm='ball_tree',n_neighbors=3, p=2, metric=inverse_euc)

推荐阅读