scikitplot -> IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

Problem description

Can anyone help me fix this error?

y_test['cEXT'].shape, y_test['cEXT'].ndim  # returns ((982,), 1)
Y_test_probs.shape, Y_test_probs.ndim      # returns ((982,), 1)

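# ROC AUC curve
import numpy as np
import scikitplot as skplt

# Drop the size-1 axis of the model output, leaving a 1-D array of shape (982,)
Y_test_probs = np.squeeze(model_cEXT.predict(X_test))

skplt.metrics.plot_roc_curve(y_test['cEXT'], Y_test_probs,
                             title="Digits ROC Curve", figsize=(12, 6));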

Error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-136-33407ee6fd1a> in <module>()
      4 
      5 skplt.metrics.plot_roc_curve(y_test['cEXT'], Y_test_probs,
----> 6                        title="Digits ROC Curve", figsize=(12,6));

/usr/local/lib/python3.7/dist-packages/scikitplot/metrics.py in plot_roc_curve(y_true, y_probas, title, curves, ax, figsize, cmap, title_fontsize, text_fontsize)
    255     roc_auc = dict()
    256     for i in range(len(classes)):
--> 257         fpr[i], tpr[i], _ = roc_curve(y_true, probas[:, i],
    258                                       pos_label=classes[i])
    259         roc_auc[i] = auc(fpr[i], tpr[i])

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
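The failing line indexes probas[:, i], which only works on a 2-D array. A minimal NumPy-only reproduction, assuming probas is a 1-D array of probabilities like Y_test_probs (the values below are made up):

```python
import numpy as np

# Hypothetical 1-D probabilities with the same shape pattern as Y_test_probs: (n_samples,)
probas = np.array([0.1, 0.8, 0.6, 0.3])
print(probas.ndim)  # 1

error_message = None
try:
    # scikitplot's loop does the equivalent of probas[:, i] for each class i
    _ = probas[:, 0]
except IndexError as exc:
    error_message = str(exc)

print(error_message)
```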

Tags: python, scikit-plot

Solution


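It looks like y_probas must be 2-D, with shape (n_samples, n_classes).

You could try adding a dimension like this:

np.expand_dims(Y_test_probs, -1)

Note, however, that this produces shape (982, 1), i.e. a single column. plot_roc_curve loops over every class found in y_true (classes[i] in the traceback), so for a binary target it will also request probas[:, 1] and raise IndexError again. Pass one probability column per class instead, for example np.column_stack([1 - Y_test_probs, Y_test_probs]) when Y_test_probs is the probability of the positive class (label 1).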
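A NumPy-only sketch of the shape check, assuming binary 0/1 labels and that Y_test_probs holds the positive-class probability from a single sigmoid output. The helper names to_two_column_probas and has_column_per_class are hypothetical, and the probability values are made up:

```python
import numpy as np


def to_two_column_probas(p_positive):
    """Convert a 1-D array of positive-class probabilities into shape
    (n_samples, 2): column 0 = P(label 0), column 1 = P(label 1)."""
    p_positive = np.asarray(p_positive, dtype=float).ravel()
    return np.column_stack([1.0 - p_positive, p_positive])


def has_column_per_class(probas, n_classes):
    """Mimic the probas[:, i] loop from the traceback; True if it never fails."""
    try:
        for i in range(n_classes):
            _ = probas[:, i]
        return True
    except IndexError:
        return False


# Hypothetical sigmoid outputs standing in for Y_test_probs
Y_test_probs = np.array([0.1, 0.8, 0.6, 0.3])

one_column = np.expand_dims(Y_test_probs, -1)
two_columns = to_two_column_probas(Y_test_probs)

print(one_column.shape, has_column_per_class(one_column, 2))    # (4, 1) False
print(two_columns.shape, has_column_per_class(two_columns, 2))  # (4, 2) True
```

With scikit-plot installed, the converted array would then be passed as y_probas, e.g. skplt.metrics.plot_roc_curve(y_test['cEXT'], to_two_column_probas(Y_test_probs)). This assumes the column order matches the sorted labels that np.unique returns for y_true.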

