首页 > 解决方案 > 不推荐使用 model.predict_classes - 改用什么?

问题描述

我一直在尝试重新访问我的 python 代码以在神经网络上进行预测,并且在运行model.predict_classes自 2021 年 1 月 1 日以来不推荐使用的代码后我意识到。

请你支持我知道我可以用什么代替我的代码吗?

代码行是:

y_pred_nn = model.predict_classes(X_test)

问题:

NameError
Traceback (most recent call last)
<ipython-input-11-fc1ddbecb622> in <module>
----> 1 print(y_pred_nn)

NameError: name 'y_pred_nn' is not defined

标签: pythontensorflowkerasneural-network

解决方案


有关如何处理此问题的最佳说明如下:

https://androidkt.com/get-class-labels-from-predict-method-in-keras/

首先用于model.predict()提取类概率。然后根据类的数量执行以下操作:

二进制分类

使用阈值选择将确定类别 0 或 1 的概率

np.where(y_pred > threshold, 1,0)

例如使用 0.5 的阈值

多级分类

选择概率最高的类

np.argmax(predictions, axis=1)

多标签分类

如果每个示例可以有多个输出类,请使用阈值来选择应用哪些标签。

y_pred = model.predict(x, axis=1)
[i for i,prob in enumerate(y_pred) if prob > 0.5]

推荐阅读