python - 我无法创建混淆矩阵
问题描述
我正在尝试创建混淆矩阵,但我不明白问题是什么。有没有可能我输入了不兼容的东西?
score = model.evaluate(X_test, y_test, verbose=1)
print("Test Score:", score[0])
print("Test Accuracy:", score[1])
# %%
y_pred = model.predict(X_test) > 0.4
cm = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))
print(cm)
#plt.show(cm)
cm_df = pd.DataFrame(cm, index = ['aid_related','other_aid','request','weather_related','direct_report','infrastructure_related','medical','primary_needs'], columns = ['primary_needs','medical','infrastructure_related','direct_report','weather_related','request','other_aid','aid_related'])
plt.figure(figsize=(5,4))
sns.heatmap(cm_df, annot=True)
plt.title('Confusion Matrix')
plt.ylabel('Actal Values')
plt.xlabel('Predicted Values')
plt.show()
#print(preds)
#result = pd.DataFrame(preds, columns=[["aid_related","other_aid","request","weather_related","direct_report","infrastructure_related","medical","primary_needs"]])
#result.head(50)
Il messaggio di errore completo è il seguente
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
c:\Users\kaka0\Desktop\distaster_messages\Untitled-1 (1).py in <module>
156 cm = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))
157 print(cm)
----> 158 plt.show(cm)
159 '''
160 cm_df = pd.DataFrame(cm, index = ['aid_related','other_aid','request','weather_related','direct_report','infrastructure_related','medical','primary_needs'], columns = ['primary_needs','medical','infrastructure_related','direct_report','weather_related','request','other_aid','aid_related'])
~\AppData\Local\Programs\Python\Python37\lib\site-packages\matplotlib\pyplot.py in show(*args, **kwargs)
376 """
377 _warn_if_gui_out_of_main_thread()
--> 378 return _backend_mod.show(*args, **kwargs)
379
380
~\AppData\Local\Programs\Python\Python37\lib\site-packages\matplotlib_inline\backend_inline.py in show(close, block)
47 # only call close('all') if any to close
48 # close triggers gc.collect, which can be slow
---> 49 if close and Gcf.get_all_fig_managers():
50 matplotlib.pyplot.close('all')
51
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
为什么总是给我这个错误?
解决方案
y_true 是真实标签,预测标签保存在变量预测中。X-test 是模型必须预测的输入标签。接下来我们从sklearn 导入confusion_matrix 并给它y_true 和y_pred 值。最后我们将混淆矩阵的输出放入热图中,您可以自己修改热图并选择是否保存。
请注意,在这种情况下,我在 x 和 y 标签中使用了男性和女性,您必须自己进行调整。
from sklearn.metrics import confusion_matrix
predictions = model.predict(X_test)
y_pred = np.round(predictions, 0)
y_true = y_true
confusion = np.round(confusion_matrix(y_true, y_pred, normalize='true'),2)
import seaborn as sb
from matplotlib.pyplot import figure
x_axis_labels = ['Female', 'Male']
y_axis_labels = ['Female', 'Male']
figure(figsize = (5,5))
sb.heatmap(confusion, xticklabels = x_axis_labels, yticklabels = y_axis_labels, cmap= "Blues", linecolor = 'black' , linewidth = .1 , annot = True, fmt='',)
plt.savefig('confusion_matrix_test.png')
plt.show()
推荐阅读
- python - 在 keras 中预测时索引超出范围
- uwp - 防止加载控件后面的键盘选项卡
- javascript - axios, fetch() 将请求标头设置为 Access-Control-Request-Headers 而不是单独的标头
- php - 如何将现有套接字转换为使用 TLS/SSL?
- javascript - Node.js /Express 和 mongoose:建立一个“可观察”的 mongodb 连接并自动拉取新数据?
- css - 如何解决这个问题?它是否与 webpack 一起使用?
- opengl - 聚光灯下的黑色边框伪影(OpenGL)
- doxygen - DoxyWizard 外观:如何让一切变得更大
- html - 字形图标未显示,但同事的代码相同
- angular6 - 如何在 Angular 6 中使用 exceljs