首页 > 解决方案 > Python - 'ndarray' 类型的对象不是 JSON 可序列化的

问题描述

我想弄清楚下面这个函数有什么问题。它能正常绘制图表,但绘制完成后仍然会显示一个 TypeError。

我其实并不在意这个错误,也已经尝试用 try/except(捕获 TypeError)把代码包起来,但不知道为什么,错误仍然会显示。

下面先贴出函数,然后是错误输出。我猜解决办法可能是使用 .tolist() 方法,但不知道该把它用在哪里。

def crossValPlot(skf,classifier,X,y):
    """Plot per-fold ROC curves plus their mean for a cross-validated classifier.

    Code adapted from the scikit-learn cross-validation ROC example.

    Parameters
    ----------
    skf : cross-validator
        Object whose ``split(X, y)`` yields ``(train, test)`` positional
        index arrays (e.g. ``StratifiedKFold``).
    classifier : estimator
        Object with ``fit`` and ``predict_proba``; it is refit on every fold.
    X : pandas.DataFrame
        Feature matrix, indexed positionally with ``.iloc``.
    y : pandas.Series
        Target labels, indexed positionally with ``.iloc``.

    Returns
    -------
    None
        The curves are drawn on a newly created figure.

    Notes
    -----
    ``TypeError: Object of type 'ndarray' is not JSON serializable`` is not
    raised inside this function. It comes from mpld3's IPython HTML display
    hook (``fig_to_html`` -> ``json.dumps(..., cls=NumpyEncoder)``). That hook
    runs when the notebook renders the figure, after this function has
    returned. This is why a try/except around the body cannot catch it.
    NOTE(review): presumably the hook was installed with
    ``mpld3.enable_notebook()``; disabling it (or upgrading mpld3) should
    avoid the error -- confirm in the caller's notebook setup.
    """
    from sklearn.metrics import roc_curve, auc

    tprs = []
    aucs = []
    # Common FPR grid onto which every fold's TPR is resampled for averaging.
    mean_fpr = np.linspace(0, 1, 100)
    f, ax = plt.subplots(figsize=(10, 7))
    for i, (train, test) in enumerate(skf.split(X, y)):
        # split() yields positional index arrays, so they go straight into
        # .iloc (wrapping them in pd.IndexSlice returned them unchanged).
        probas_ = (classifier.fit(X.iloc[train], y.iloc[train])
                   .predict_proba(X.iloc[test]))
        # ROC curve and its area for this fold; column 1 of predict_proba is
        # presumably the positive class -- confirm for the label encoding used.
        fpr, tpr, _ = roc_curve(y.iloc[test], probas_[:, 1])
        # np.interp replaces scipy.interp, a deprecated alias that newer
        # SciPy releases removed.
        tprs.append(np.interp(mean_fpr, fpr, tpr))
        # Pin the resampled curve to start at the origin.
        tprs[-1][0] = 0.0
        roc_auc = auc(fpr, tpr)
        aucs.append(roc_auc)
        ax.plot(fpr, tpr, lw=1, alpha=0.3,
                label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))

    # Diagonal reference line: the ROC of a random classifier.
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
            label='Luck', alpha=.8)

    mean_tpr = np.mean(tprs, axis=0)
    # Pin the mean curve to end at (1, 1).
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs)
    ax.plot(mean_fpr, mean_tpr, color='b',
            label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
            lw=2, alpha=.8)

    # Shade +/- one standard deviation of TPR, clipped to [0, 1].
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
                    label=r'$\pm$ 1 std. dev.')

    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver operating characteristic example')
    ax.legend(bbox_to_anchor=(1,1))

这是错误输出:

---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    ~\Anaconda3\lib\site-packages\IPython\core\formatters.py in __call__(self, obj)
        339                 pass
        340             else:
    --> 341                 return printer(obj)
        342             # Finally look for special method names
        343             method = get_real_method(obj, self.print_method)

    ~\Anaconda3\lib\site-packages\mpld3\_display.py in <lambda>(fig, kwds)
        408     formatter = ip.display_formatter.formatters['text/html']
        409     formatter.for_type(Figure,
    --> 410                        lambda fig, kwds=kwargs: fig_to_html(fig, **kwds))
        411 
        412 

    ~\Anaconda3\lib\site-packages\mpld3\_display.py in fig_to_html(fig, d3_url, mpld3_url, no_extras, template_type, figid, use_http, **kwargs)
        249                            d3_url=d3_url,
        250                            mpld3_url=mpld3_url,
    --> 251                            figure_json=json.dumps(figure_json, cls=NumpyEncoder),
        252                            extra_css=extra_css,
        253                            extra_js=extra_js)

    ~\Anaconda3\lib\json\__init__.py in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
        236         check_circular=check_circular, allow_nan=allow_nan, indent=indent,
        237         separators=separators, default=default, sort_keys=sort_keys,
    --> 238         **kw).encode(obj)
        239 
        240 

    ~\Anaconda3\lib\json\encoder.py in encode(self, o)
        197         # exceptions aren't as detailed.  The list call should be roughly
        198         # equivalent to the PySequence_Fast that ''.join() would do.
    --> 199         chunks = self.iterencode(o, _one_shot=True)
        200         if not isinstance(chunks, (list, tuple)):
        201             chunks = list(chunks)

    ~\Anaconda3\lib\json\encoder.py in iterencode(self, o, _one_shot)
        255                 self.key_separator, self.item_separator, self.sort_keys,
        256                 self.skipkeys, _one_shot)
    --> 257         return _iterencode(o, 0)
        258 
        259 def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,

    ~\Anaconda3\lib\site-packages\mpld3\_display.py in default(self, obj)
        136             numpy.float64)):
        137             return float(obj)
    --> 138         return json.JSONEncoder.default(self, obj)
        139 
        140 

    ~\Anaconda3\lib\json\encoder.py in default(self, o)
        178         """
        179         raise TypeError("Object of type '%s' is not JSON serializable" %
    --> 180                         o.__class__.__name__)
        181 
        182     def encode(self, o):

    TypeError: Object of type 'ndarray' is not JSON serializable

标签: python, numpy, typeerror

解决方案


推荐阅读