python - 编写机器学习分类器算法
问题描述
我试图为机器学习模型编写分类器算法,但出现错误。有人可以帮忙吗?提前致谢
import pandas as pd
from sklearn.metrics import accuracy_score
from scipy.spatial import distance
def euc(a, b):
return distance.euclidean(a,b)
class classifierKN():
def fit(self, X_train, Y_train):
self.X_train = X_train
self.Y_train = Y_train
def predict(self, X_test):
predictions = []
for row in X_test:
label = self.closest(row)
predictions.append(label)
return predictions
def closest(self, row):
best_dist = euc(row, self.X_train[0])
best_index = 0
for i in range(1, len(self.X_train)):
dist = euc(row, self.X_train[i])
if dist < best_dist:
best_dist = dist
best_index = i
return self.Y_train[best_index]
#Load the dataset
diabetdata = pd.read_csv("diabetes.csv")
#set features and target
features = ["PlasmaGlucose", "DiastolicBloodPressure", "TricepsThickness", "SerumInsulin"]
X = diabetdata[features]
print("FEATURES: " , X.head())
Y = diabetdata.Diabetic
print("TARGET: " , Y.head())
print("")
from sklearn.model_selection import train_test_split #No module named 'sklearn.cross_validation' so I replace it with model_selection
X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size=0.3, random_state=0)
#predict
model= classifierKN()
model.fit(X_train,Y_train)
predictKN = model.predict(X)
print ("Predict result with KNeighborsClassifier")
print(predictKN)
#accuracy
print("Accuracy")
print (accuracy_score(Y, predictKN))
结果
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\Vlad\Desktop\Machine learning\Machine Learning\coursework\test2.py", line 63, in <module>
predictKN = model.predict(X)
File "C:\Users\Vlad\Desktop\Machine learning\Machine Learning\coursework\test2.py", line 26, in predict
label = self.closest(row)
File "C:\Users\Vlad\Desktop\Machine learning\Machine Learning\coursework\test2.py", line 30, in closest
best_dist = euc(row, self.X_train[0])
File "E:\Anaconda\lib\site-packages\pandas\core\frame.py", line 2800, in __getitem__
indexer = self.columns.get_loc(key)
File "E:\Anaconda\lib\site-packages\pandas\core\indexes\base.py", line 2648, in get_loc
return self._engine.get_loc(self._maybe_cast_indexer(key))
File "pandas\_libs\index.pyx", line 111, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\index.pyx", line 138, in pandas._libs.index.IndexEngine.get_loc
File "pandas\_libs\hashtable_class_helper.pxi", line 1619, in pandas._libs.hashtable.PyObjectHashTable.get_item
File "pandas\_libs\hashtable_class_helper.pxi", line 1627, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 0
解决方案
您的代码实际上同时存在多个问题,因此理解它有点困难。您的问题似乎主要与您对 pandas Dataframes/Series 的理解有关,因为您显然试图通过以下方式迭代数据框的行:
def predict(self, X_test):
predictions = []
for row in X_test:
label = self.closest(row)
predictions.append(label)
return predictions
这在熊猫中不起作用。要实际迭代行的值,您需要以下内容:
def predict(self, X_test):
predictions = []
for row in X_test.iterrows():
label = self.closest(list(row[1]))
predictions.append(label)
return predictions
此函数实际上会遍历数据框中的行,并将行的值提供给closest()
函数。`
def closest(self, row):
best_dist = euc(row, self.X_train[0])
best_index = 0
for i in range(1, len(self.X_train)):
dist = euc(row, self.X_train[i])
if dist < best_dist:
best_dist = dist
best_index = i
return self.Y_train[best_index]
但是,此功能不起作用,因为您基本上是在尝试使用best_dist = euc(row, self.X_train[0])
. 这只会向您抛出一个 keyError,因为 X_train 是一个 Dataframe 并且没有第 0 列(无论如何您都不想索引该列)。您想要的是默认的 best_dist 作为输入行与数据框中第一行之间的距离。这将适用于类似
best_dist = euc(row, self.X_train.iloc[0])
. 然后,您需要遍历 X_train 中的行(这里您的函数存在与以前相同的问题),因此您需要将其更改为:
def closest(self, row):
best_dist = euc(row, self.X_train.iloc[0])
best_index = 0
for i in range(1, len(self.X_train.index)):
dist = euc(row, list(self.X_train.iloc[i]))
if dist < best_dist:
best_dist = dist
best_index = i
return self.Y_train.iloc[best_index]
这至少有效。无论它是否为您提供所需的输出和/或是否足够准确,我无法保证,但它确实解决了您的直接问题。
推荐阅读
- google-sheets - 如何根据另一个工作表中的列表过滤行?(谷歌表格)
- kubernetes - 如何从 Ansible 将变量传递给 Kubernetes YAML 文件?
- python - Beautiful Soup 4 Python 3.6.5 网络抓取,用于从英镑到美元的实时货币转换
- javascript - 尝试“放置”JSON 数据时出错
- angular - Angular5:ngOnChanges 被称为非常慢
- angular - Angular6.x 库未公开 public_api.ts 中的所有成员
- oracle-adf - ADF 应用程序模块是否支持 H2 或任何其他内存数据库
- python - Pandas/numpy 数组填充
- java - 使用循环将数字 1 添加到 n
- python - 即使在 Ajax 页面中显式等待后,Selenium 也找不到元素