python - 多类分类和概率预测
问题描述
import pandas as pd
import numpy
from sklearn import cross_validation
from sklearn.naive_bayes import GaussianNB
fi = "df.csv"
# Open the file for reading and read in data
file_handler = open(fi, "r")
data = pd.read_csv(file_handler, sep=",")
file_handler.close()
# split the data into training and test data
train, test = cross_validation.train_test_split(data,test_size=0.6, random_state=0)
# initialise Gaussian Naive Bayes
naive_b = GaussianNB()
train_features = train.ix[:,0:127]
train_label = train.iloc[:,127]
test_features = test.ix[:,0:127]
test_label = test.iloc[:,127]
naive_b.fit(train_features, train_label)
test_data = pd.concat([test_features, test_label], axis=1)
test_data["p_malw"] = naive_b.predict_proba(test_features)
print "test_data\n",test_data["p_malw"]
print "Accuracy:", naive_b.score(test_features,test_label)
我编写了这段代码来接受来自具有 128 列的 csv 文件的输入,其中 127 列是特征,第 128 列是类标签。
我想预测样本属于每个类别的概率(有 5 个类别(1-5))并将其打印在矩阵中并根据预测确定样本类别。predict_proba() 没有给出所需的输出。请提出所需的更改。
解决方案
GaussianNB.predict_proba 返回模型中每个类的样本概率。在您的情况下,它应该返回一个包含五列的结果,这些列的行数与您的测试数据中的行数相同。您可以使用 naive_b.classes_ 验证哪个列对应于哪个类。因此,尚不清楚您为什么说这不是所需的输出。也许,您的问题来自您将预测概率的输出分配给数据框列的事实。尝试:
pred_prob = naive_b.predict_proba(test_features)
代替
test_data["p_malw"] = naive_b.predict_proba(test_features)
并使用 pred_prob.shape 验证其形状。第二个维度应该是 5。
如果您想要每个样本的预测标签,您可以使用 predict 方法,然后使用混淆矩阵来查看正确预测了多少标签。
from sklearn.metrics import confusion_matrix
naive_B.fit(train_features, train_label)
pred_label = naive_B.predict(test_features)
confusion_m = confusion_matrix(test_label, pred_label)
confusion_m
这里有一些有用的读物。
sklearn GaussianNB - http://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html#sklearn.naive_bayes.GaussianNB.predict_proba
sklearn 混淆矩阵 - http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
推荐阅读
- python - 如何在 Pandas 中将值从 1 列分配给另一列并发出警告
- html - 如何在弹性容器中设置间隙(排水沟)?
- reactjs - 无法使用 fontawesome pro 将 gatsby 站点部署到 Netlify
- javascript - babel-plugin-rewire: 测试私有方法
- oracle - 如何动态使用源/目标表进行合并以及在 Oracle 中动态选择更新语句的列
- c++ - 如何在 Ubuntu 的 VSCodium 中设置 Googletest
- string - Swift 5:识别字符串中所有不同的子字符串
- amazon-web-services - AWS RDS Aurora - 如何使用 PgAdmin 进行连接?
- c# - 无法将 lambda 表达式转换为类型 'System.Linq.Expressions.Expression
' 因为它不是委托类型。挂火错误 - php - 未使用自定义 MAIL_LOG_CHANNEL