python - keras sequence().predict(x_test) 只为两个类返回 1 列
问题描述
我对 keras 有问题sequential().predict(x_test)
。
顺便说一句,使用sequential().predict_proba(x_test)
我发现这两个现在顺序无关紧要的输出相同。
我的数据有两个类别:0 或 1,我认为predict(x_test)
应该给出两列,其中第一列是获得 0 的概率,第二列是获得 1 的概率。但是我只有一列。
In [85]:y_train.value_counts()
Out[85]:
0 616751
1 11140
Name: _merge, dtype: int64
我的数据应该没有问题,因为我对 LogisticRegression 模型和神经网络模型都使用了相同的 x_train、y_train、x_test、y_test,它在 LogisticRegression 中完美运行。
In [87]:y_pred_LR
Out[87]:
array([[ 9.96117151e-01, 3.88284921e-03],
[ 9.99767583e-01, 2.32417329e-04],
[ 9.87375774e-01, 1.26242258e-02],
...,
[ 9.72159138e-01, 2.78408623e-02],
[ 9.97232916e-01, 2.76708432e-03],
[ 9.98146985e-01, 1.85301489e-03]])
但我在神经网络模型中只得到 1 列。
所以我想NN模型设置有问题吗?这是我的代码
NN = Sequential()
NN.add(Dense(40, input_dim = 65, kernel_initializer = 'uniform', activation = 'relu'))
NN.add(Dense(20, kernel_initializer = 'uniform', activation = 'relu'))
NN.add(Dense(1, kernel_initializer = 'uniform', activation = 'sigmoid'))
NN.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
NN.fit(x_train, y_train, batch_size = 50, epochs=5)
y_pred_NN = NN.predict(x_test)
print(y_pred_NN)
In [86]: print(y_pred_NN)
[[ 0.00157279]
[ 0.0010451 ]
[ 0.03178826]
...,
[ 0.01030775]
[ 0.00584918]
[ 0.00186538]]
实际上它看起来像是获得1的概率?任何帮助表示赞赏!
顺便说一句,我在两个模型中的预测形状如下
In [91]:y_pred_LR.shape
Out[91]: (300000, 2)
In [90]:y_pred_NN.shape
Out[90]: (300000, 1)
解决方案
如果要输出两个概率,则必须替换y_train
为to_categorical(y_train)
,然后相应地调整网络:
from keras.utils import to_categorical
NN = Sequential()
NN.add(Dense(40, input_dim = 65, kernel_initializer = 'uniform', activation = 'relu'))
NN.add(Dense(20, kernel_initializer = 'uniform', activation = 'relu'))
NN.add(Dense(2, activation='sigmoid'))
NN.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
NN.fit(x_train, to_categorical(y_train), batch_size = 50, epochs=5)
推荐阅读
- angularjs - AngularJS 1.6 和 $onInit-Hook
- vba - 根据其他单元格中的值清除单元格内容
- xpages - 如何使用 xpage 将文本文件写入特定文件夹
- c# - 与 WebBrowser 控件中的 AdobeReader 插件交互
- ios - 在展开 Optional 值时意外发现 nil?斯威夫特4
- android - CreateTabBottomNavigator 和滑动手势。刷卡很慢,为什么?
- linux - 如何测量交换信号的 2 个脚本 linux 之间的最大速度?
- django - Django CreateView 未按预期工作
- node.js - 无法在 TypeScript 中导入 Google Storage
- python - 如何使用 python 套接字通过 Internet 成功传输消息?