python - 如何在模型内部改变 Keras 模型的输出阈值?
问题描述
我正在构建一个 Keras 模型,这样:
Y = 1 for X >= 0.5
Y = 0 for X < 0.5
我的模型:
def define_model():
model = Sequential()
model.build(input_shape = (None, 1))
model.add(Dense(1, activation = 'sigmoid'))
opt = SGD(learning_rate = 0.01, momentum = 0.99)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
创建模型后,我将权重设置为 [1] 并将偏差设置为 [-0.5] 现在,我得到了很高的准确性,但以下输入的输出错误:
[[0.50000006]
[0.5 ]
[0.50000001]
[0.50000002]
[0.50000007]
[0.50000004]
[0.50000001]
[0.50000004]
[0.50000004]
[0.50000001]
[0.50000003]
[0.50000007]
[0.50000008]
[0.50000008]
[0.50000004]
[0.50000002]
[0.50000006]
[0.50000006]
[0.5000001 ]
[0.50000008]
[0.50000002]
[0.50000004]
[0.50000006]
[0.50000004]
[0.5 ]
[0.50000005]
[0.50000003]
[0.50000007]
[0.50000004]
etc.
所以,模型已经知道了,Y = 1 for only X > 0.5
但我需要Y = 1 for X >= 0.5
.
我知道这可以通过获取输出pred = model.predict(X)
然后手动比较来完成。但我希望这在模型内部完成。model.predict_classes
应该有一个门槛。我想改变这个阈值。我怎么能这样做?
解决方案
predict_classes
不允许我们更改阈值。这就是 keras 的实现方式
def predict_classes(self, x, batch_size=32, verbose=0):
proba = self.predict(x, batch_size=batch_size, verbose=verbose)
if proba.shape[-1] > 1:
return proba.argmax(axis=-1)
else:
return (proba > 0.5).astype('int32')
如果您想拥有自己的阈值,那么您将不得不重载该方法。
代码
class MySequential(keras.models.Sequential):
def __init__(self, **kwargs):
super(MySequential, self).__init__(**kwargs)
def predict_classes(self, x, batch_size=32, verbose=0):
proba = self.predict(x, batch_size=batch_size, verbose=verbose)
return (proba >= 0.6).astype('int32')
def define_model():
model = MySequential()
model.add(keras.layers.Dense(1, activation = 'sigmoid', input_shape=(None, 1)))
opt = keras.optimizers.SGD(learning_rate = 0.01, momentum = 0.99)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
return model
# Test
model = define_model()
x = np.random.randn(5)
print (model.predict(x))
print (model.predict_classes(x))
推荐阅读
- kubernetes - 具有复制功能的两台服务器上的 Kubernetes Cassandra 负载平衡
- jsf - 添加新数据后 Primefaces 数据表重新加载不起作用
- php - 为什么我的 AJAX 方法将表单输入发送到 URL?
- kubernetes - 在 Kubernetes 上采用 Unifi 设备
- postgresql - 如何在 postgres 中对相同的 CTE 表达式执行 UNION ALL?
- spyder - Outline Explorer Spyder 5 不显示模块、单元格等
- verilog - 何时在 SystemVerilog 中使用 `include
- vba - 如何使用 VBA 在 Powerpoint 中设置音频的启动选项
- reactjs - 如何模拟使用 Axios 进行的 API 调用?
- python - 如何使用 Tkinter 进行实时数据库更新