I can't create a confusion matrix

Problem description

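I'm trying to create a confusion matrix, but I don't understand what the problem is. Could it be that I'm passing in something incompatible?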

score = model.evaluate(X_test, y_test, verbose=1)

print("Test Score:", score[0])
print("Test Accuracy:", score[1])
# %%

y_pred = model.predict(X_test) > 0.4
cm = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))
print(cm)
#plt.show(cm)
cm_df = pd.DataFrame(cm, index = ['aid_related','other_aid','request','weather_related','direct_report','infrastructure_related','medical','primary_needs'], columns = ['primary_needs','medical','infrastructure_related','direct_report','weather_related','request','other_aid','aid_related'])
plt.figure(figsize=(5,4))
sns.heatmap(cm_df, annot=True)
plt.title('Confusion Matrix')
plt.ylabel('Actual Values')
plt.xlabel('Predicted Values')
plt.show()
#print(preds)
#result = pd.DataFrame(preds, columns=[["aid_related","other_aid","request","weather_related","direct_report","infrastructure_related","medical","primary_needs"]])
#result.head(50)

The full error message is as follows:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
c:\Users\kaka0\Desktop\distaster_messages\Untitled-1 (1).py in <module>
      156 cm = confusion_matrix(y_test.argmax(axis=1), y_pred.argmax(axis=1))
      157 print(cm)
----> 158 plt.show(cm)
      159 '''
      160 cm_df = pd.DataFrame(cm, index = ['aid_related','other_aid','request','weather_related','direct_report','infrastructure_related','medical','primary_needs'], columns = ['primary_needs','medical','infrastructure_related','direct_report','weather_related','request','other_aid','aid_related'])

~\AppData\Local\Programs\Python\Python37\lib\site-packages\matplotlib\pyplot.py in show(*args, **kwargs)
    376     """
    377     _warn_if_gui_out_of_main_thread()
--> 378     return _backend_mod.show(*args, **kwargs)
    379 
    380 

~\AppData\Local\Programs\Python\Python37\lib\site-packages\matplotlib_inline\backend_inline.py in show(close, block)
     47         # only call close('all') if any to close
     48         # close triggers gc.collect, which can be slow
---> 49         if close and Gcf.get_all_fig_managers():
     50             matplotlib.pyplot.close('all')
     51 

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()


Why does it always give me this error?

Tags: python, deep-learning, sklearn-pandas

Solution


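y_true holds the true labels. The model's raw outputs are stored in the variable predictions and rounded to 0/1 labels in y_pred. X_test is the input data the model has to make predictions on. Next, import confusion_matrix from sklearn and pass it y_true and y_pred. Finally, plot the confusion matrix as a heatmap; you can customize the heatmap and choose whether to save it to a file.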
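The step from predictions to y_pred depends on the model's output layer. The np.round call in this answer fits a single sigmoid output (binary case). The question's y_test.argmax(axis=1) suggests one-hot targets with one output per class, where argmax on the raw probabilities is the usual choice. Thresholding first, as model.predict(X_test) > 0.4 does, turns each row into booleans, and argmax then returns the first class above 0.4, or class 0 when no class passes. A minimal sketch covering both cases; the helper name to_class_labels and its 0.5 default are illustrative assumptions, not part of the original answer:

```python
import numpy as np

def to_class_labels(probs, threshold=0.5):
    """Turn model outputs into a 1-D array of integer class labels.

    Hypothetical helper for illustration:
    - 2-D output with several columns (softmax, one-hot targets): index of
      the largest probability in each row.
    - 1-D or single-column output (one sigmoid unit): 1 where the
      probability is above `threshold`, else 0. With threshold=0.5 this
      matches np.round(probs, 0) for probabilities in [0, 1].
    """
    probs = np.asarray(probs)
    if probs.ndim == 2 and probs.shape[1] > 1:
        return probs.argmax(axis=1)
    return (probs.ravel() > threshold).astype(int)

# Multi-class: 3 samples, 4 classes
softmax_out = np.array([[0.1, 0.7, 0.1, 0.1],
                        [0.6, 0.2, 0.1, 0.1],
                        [0.2, 0.2, 0.5, 0.1]])
print(to_class_labels(softmax_out))   # [1 0 2]

# Binary: one sigmoid unit per sample
sigmoid_out = np.array([[0.2], [0.9], [0.5]])
print(to_class_labels(sigmoid_out))   # [0 1 0]
```

The strict ">" is deliberate: np.round rounds halves to the nearest even number, so 0.5 becomes 0, and the helper reproduces that for the binary case.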

Note that here I used Female and Male as the x- and y-axis labels; you will need to replace them with your own class names.
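For the eight classes in the question, confusion_matrix orders the rows (true labels) and the columns (predicted labels) the same way: by the sorted label values, or by the labels argument when it is given. The DataFrame's index and columns must therefore use the same list in the same order; the question's columns list is its index list reversed, which mislabels the x axis. Passing labels= also keeps the matrix 8 by 8 when some class never occurs in the test batch. A sketch assuming the class names are listed in the same order as the columns of the one-hot y_test; the toy arrays stand in for real data:

```python
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix

# Assumption: the i-th name corresponds to the i-th column of the one-hot y_test
class_names = ['aid_related', 'other_aid', 'request', 'weather_related',
               'direct_report', 'infrastructure_related', 'medical', 'primary_needs']

# Toy integer labels standing in for y_test.argmax(axis=1) and the predicted labels
y_true = np.array([0, 1, 2, 2, 7, 3])
y_pred = np.array([0, 2, 2, 2, 7, 3])

# labels= fixes the row/column order and keeps classes absent from this batch
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

# Rows are true classes, columns are predicted classes, both in the same order
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)
print(cm_df.loc['other_aid', 'request'])   # 1
```

The resulting cm_df can then be passed to sns.heatmap(cm_df, annot=True, fmt='d'), where fmt='d' prints the counts as integers.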

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
from sklearn.metrics import confusion_matrix

# Binary model with one sigmoid output: round probabilities to 0/1 labels
predictions = model.predict(X_test)
y_pred = np.round(predictions, 0)
y_true = y_test  # true 0/1 labels; adjust to the name of your label array

# normalize='true' divides each row by the number of samples of that true class
confusion = np.round(confusion_matrix(y_true, y_pred, normalize='true'), 2)

x_axis_labels = ['Female', 'Male']
y_axis_labels = ['Female', 'Male']
plt.figure(figsize=(5, 5))

sb.heatmap(confusion, xticklabels=x_axis_labels, yticklabels=y_axis_labels,
           cmap="Blues", linecolor='black', linewidths=.1, annot=True, fmt='')
plt.savefig('confusion_matrix_test.png')
plt.show()
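As for the error itself: the traceback points at plt.show(cm), not at confusion_matrix (print(cm) on the line before it runs fine). plt.show does not take data to display; the inline backend's show(close, block) receives the matrix as its close argument and then evaluates if close and Gcf.get_all_fig_managers(), and a NumPy array with more than one element has no single truth value. A minimal reproduction of that check with plain NumPy, where fake_show is a hypothetical stand-in for the backend function, not matplotlib code:

```python
import numpy as np

def fake_show(close=None, block=None):
    # Mimics the check from the traceback: `if close and Gcf.get_all_fig_managers():`
    if close and True:
        return "closed figures"
    return "nothing to close"

cm = np.array([[5, 1], [2, 7]])

# Like plt.show() with no arguments: close is None, so the check is simply False
print(fake_show())   # nothing to close

# Like plt.show(cm): cm becomes `close`, and `cm and ...` needs bool(cm)
try:
    fake_show(cm)
except ValueError as err:
    print("ValueError:", err)
```

The fix is to draw the heatmap first and then call plt.show() with no arguments.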
