How to change the output threshold of a Keras model inside the model?

Problem description

I am building a Keras model such that:

Y = 1 for X >= 0.5
Y = 0 for X < 0.5

My model:

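from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD

def define_model():
    model = Sequential()
    # A single sigmoid unit on a one-feature input (declare the shape on the layer
    # instead of calling build() on an empty model)
    model.add(Dense(1, activation='sigmoid', input_shape=(1,)))

    opt = SGD(learning_rate=0.01, momentum=0.99)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

    return model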

After creating the model, I set the weight to [1] and the bias to [-0.5]. Now I get high accuracy, but the output is wrong for the following inputs:

[[0.50000006]
 [0.5       ]
 [0.50000001]
 [0.50000002]
 [0.50000007]
 [0.50000004]
 [0.50000001]
 [0.50000004]
 [0.50000004]
 [0.50000001]
 [0.50000003]
 [0.50000007]
 [0.50000008]
 [0.50000008]
 [0.50000004]
 [0.50000002]
 [0.50000006]
 [0.50000006]
 [0.5000001 ]
 [0.50000008]
 [0.50000002]
 [0.50000004]
 [0.50000006]
 [0.50000004]
 [0.5       ]
 [0.50000005]
 [0.50000003]
 [0.50000007]
 [0.50000004]
etc.
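To see why these inputs are misclassified, the fixed-weight model can be reproduced in NumPy. This is a minimal sketch assuming the single Dense unit computes sigmoid(w·x + b) in float32 with w = 1 and b = -0.5, as set above; `forward` is a hypothetical helper name. At X = 0.5 the pre-activation is exactly 0, so the sigmoid output is exactly 0.5, and a strict comparison (0.5 > 0.5) is false. Inputs only a few float32 steps above 0.5, like those listed, give a pre-activation of roughly 1e-7; its sigmoid differs from 0.5 by less than float32 can resolve there, so the output can also round to exactly 0.5.

```python
import numpy as np

def forward(x, w=1.0, b=-0.5):
    # Hypothetical NumPy stand-in for Dense(1, activation='sigmoid')
    # with kernel w and bias b, computed in float32
    z = w * np.asarray(x, dtype=np.float32) + b  # pre-activation w*x + b
    return 1.0 / (1.0 + np.exp(-z))              # logistic sigmoid

x = np.array([0.4, 0.5, 0.6], dtype=np.float32)
proba = forward(x)
print(proba)                            # the middle value is exactly 0.5
print((proba > 0.5).astype("int32"))    # strict rule: [0 0 1]
print((proba >= 0.5).astype("int32"))   # inclusive rule: [0 1 1]
```

Only the comparison operator differs between the last two lines, which is exactly the gap between "X > 0.5" and "X >= 0.5".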

So the model has effectively learned Y = 1 only for X > 0.5, but I need Y = 1 for X >= 0.5.

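I know this can be done by getting the output with pred = model.predict(X) and then comparing it against the threshold manually, but I want this to happen inside the model. model.predict_classes must use some threshold, and I want to change that threshold. How can I do this?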
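For reference, the manual approach described in the question amounts to one vectorized comparison on the predicted probabilities. This is a minimal NumPy sketch; `classes_from_proba` is a hypothetical helper name, and the hard-coded `pred` array stands in for the output of `model.predict(X)`.

```python
import numpy as np

def classes_from_proba(proba, threshold=0.5):
    # Hypothetical helper: convert sigmoid outputs to 0/1 labels.
    # ">=" makes an output exactly equal to the threshold count as class 1.
    return (np.asarray(proba) >= threshold).astype("int32")

# In the question's setting this would be: pred = model.predict(X)
pred = np.array([[0.49], [0.5], [0.51]], dtype=np.float32)
print(classes_from_proba(pred))  # labels 0, 1, 1
```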

Tags: python tensorflow keras neural-network predict

Solution


predict_classes does not let you change the threshold. This is how Keras implements it:

def predict_classes(self, x, batch_size=32, verbose=0):
    proba = self.predict(x, batch_size=batch_size, verbose=verbose)
    if proba.shape[-1] > 1:
      return proba.argmax(axis=-1)
    else:
      return (proba > 0.5).astype('int32')
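The two branches of this method can be checked without TensorFlow by applying the same logic to a plain NumPy array. The sketch below assumes `proba` has the (samples, outputs) shape returned by `predict`; `keras_style_predict_classes` is a hypothetical name. It shows that the 0.5 cutoff applies only to a single-column (sigmoid) output, and that because the comparison is strict, a probability of exactly 0.5 maps to class 0.

```python
import numpy as np

def keras_style_predict_classes(proba):
    # Hypothetical NumPy mirror of the predict_classes logic quoted above
    proba = np.asarray(proba)
    if proba.shape[-1] > 1:
        # Several output columns (e.g. softmax): pick the most probable class
        return proba.argmax(axis=-1)
    # Single sigmoid column: strict "> 0.5", so exactly 0.5 becomes class 0
    return (proba > 0.5).astype("int32")

print(keras_style_predict_classes([[0.2, 0.8], [0.6, 0.4]]))  # [1 0]
print(keras_style_predict_classes([[0.5], [0.7]]))           # column with 0 and 1
```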

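If you want your own threshold, you have to override this method in a subclass of Sequential. (TensorFlow 2.6 and later removed predict_classes from Sequential altogether; there, the subclass below simply adds it as a new method.)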

Code

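import numpy as np
from tensorflow import keras

class MySequential(keras.models.Sequential):
    """Sequential model whose predict_classes uses a configurable, inclusive threshold."""

    def predict_classes(self, x, batch_size=32, verbose=0, threshold=0.5):
        proba = self.predict(x, batch_size=batch_size, verbose=verbose)
        # ">=" so that an output exactly equal to the threshold maps to class 1
        return (proba >= threshold).astype('int32')

def define_model():
    model = MySequential()
    # One input feature per sample, so the input shape is (1,)
    model.add(keras.layers.Dense(1, activation='sigmoid', input_shape=(1,)))

    opt = keras.optimizers.SGD(learning_rate=0.01, momentum=0.99)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])

    return model

# Test

model = define_model()
# Reproduce the setup from the question: weight [[1]], bias [-0.5]
model.set_weights([np.array([[1.0]]), np.array([-0.5])])
x = np.array([[0.4], [0.5], [0.6]])
print(model.predict(x))
print(model.predict_classes(x))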
