Why doesn't my binary classification model learn, or even overfit?

Problem description

I have the following model, using tensorflow 2.2.0 with keras:

from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def get_model(input_shape):
  model = keras.Sequential()
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

The input shape is (25, 25, 4): a 3-dimensional image, 25x25 px, with 4 channels. The model won't learn; it won't even overfit! I'm trying to fit it with the following incantation:

model.compile(optimizer='sgd', metrics=['accuracy'], loss='binary_crossentropy')
model.fit(trainX, trainY, validation_split=0.2, epochs=10, batch_size=50)

I've also tried changing the optimizer to sgd with the same results, and tried different batch sizes (including 1). Here's an example training run over 10 epochs:

Epoch 1/10
763/763 [==============================] - 4s 5ms/step - loss: 0.6935 - accuracy: 0.5045 - val_loss: 0.6937 - val_accuracy: 0.5031
Epoch 2/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5020 - val_loss: 0.6946 - val_accuracy: 0.4972
Epoch 3/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5016 - val_loss: 0.6932 - val_accuracy: 0.4984
Epoch 4/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6934 - accuracy: 0.5020 - val_loss: 0.6932 - val_accuracy: 0.4986
Epoch 5/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5027 - val_loss: 0.6934 - val_accuracy: 0.4972
Epoch 6/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5051 - val_loss: 0.6946 - val_accuracy: 0.5019
Epoch 7/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6932 - val_accuracy: 0.4959
Epoch 8/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6934 - val_accuracy: 0.5056
Epoch 9/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5040 - val_loss: 0.6931 - val_accuracy: 0.5009
Epoch 10/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5018 - val_loss: 0.6931 - val_accuracy: 0.5020
<tensorflow.python.keras.callbacks.History at 0x7f761a0856d8>
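
Worth noting: a loss pinned at roughly 0.693 is -ln(0.5), which is what binary cross-entropy reports when the network outputs about 0.5 for every sample, i.e. it is effectively guessing. A quick check of that number (illustrative only, not part of the training code):

import numpy as np

# Binary cross-entropy if the model always predicts p = 0.5
p = 0.5
print(-(0.5 * np.log(p) + 0.5 * np.log(1 - p)))  # ~0.6931, matching the plateau above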

For what it's worth, the data is almost certainly not the problem: I've tried other machine-learning approaches on it, such as random forests and gradient boosting, and they are perfectly able to overfit it.

Am I missing something basic here?

Edit: setting the conv layers' activation to relu doesn't help. The output below is with relu:


Epoch 1/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6936 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 2/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5026 - val_loss: 0.6931 - val_accuracy: 0.5043
Epoch 3/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 4/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5004 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 5/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.4992 - val_loss: 0.6932 - val_accuracy: 0.5029
Epoch 6/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 7/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 8/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5001 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 9/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5029 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 10/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5012 - val_loss: 0.6931 - val_accuracy: 0.5029
<tensorflow.python.keras.callbacks.History at 0x7f29766804a8>

I also tried changing the labels to categorical and using categorical_crossentropy, to no avail.
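
For reference, that variant would have looked roughly like the sketch below; the exact code from that attempt isn't shown in the post, and it assumes the model's final layer was swapped to Dense(2, activation='softmax'):

from tensorflow.keras.utils import to_categorical

# One-hot encode the 0/1 labels: shape (N,) -> (N, 2)
trainY_cat = to_categorical(trainY, num_classes=2)

# Assumes get_model()'s last layer is Dense(2, activation='softmax') for this variant
model = get_model(input_shape=(25, 25, 4))
model.compile(optimizer='sgd', metrics=['accuracy'], loss='categorical_crossentropy')
model.fit(trainX, trainY_cat, validation_split=0.2, epochs=10, batch_size=50)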

Edit 2: with the activations set correctly, the same behavior persists over many more epochs.

Model:

def get_model(input_shape):
  model = keras.Sequential()
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

Output:

Epoch 1/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6937 - accuracy: 0.4998 - val_loss: 0.6931 - val_accuracy: 0.5008
...
Epoch 243/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 244/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 245/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5014 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 246/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5035 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 247/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 248/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5026 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 249/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5018 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 250/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029

Data sample:

display(trainX[0])
display(trainX[0].shape)
---
array([[[-0.81307793, -0.80876915, -0.80270227, -0.81340067],
        [-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82222369, -0.82112803, -0.82649334, -0.83150323]],

       [[-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.81339844, -0.80925606, -0.80577279, -0.80666623],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943]],

       [[-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.80639256, -0.81028192, -0.81510641, -0.82501505],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633]],

       [[-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82485699, -0.82883834, -0.8362085 , -0.84494163],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789]],

       [[-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.84479157, -0.84125479, -0.83585723, -0.84474687],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108]],

       [[-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84449925, -0.84380921, -0.83303354, -0.84062854],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027]],

       [[-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84077276, -0.84238718, -0.83064506, -0.83830071],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199]],

       [[-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.83819763, -0.83967254, -0.82803413, -0.83735488],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649]],

       [[-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83705204, -0.83608379, -0.83232385, -0.83126675],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427]],

       [[-0.82222369, -0.82112803, -0.82649334, -0.83150323],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427],
        [-0.83116192, -0.83311339, -0.84889275, -0.84876322]]])
(10, 10, 4)

display(trainY[0:5])
display(trainY.shape)
---
array([0, 1, 0, 1, 0], dtype=int64)
(47666,)

Tags: python, tensorflow, machine-learning, keras, deep-learning

Solution


The model isn't learning because the convolutional layers have a linear activation function: if you don't specify one, activation defaults to None, which means a linear activation. The activation function usually used with conv layers is ReLU, so simply add activation='relu' to your conv layers.
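
A minimal sketch of that fix (this just restates the answer; the parameters mirror the question's first conv layer):

from tensorflow.keras.layers import Conv2D

# As posted: no activation given, so activation defaults to None (linear)
conv_linear = Conv2D(32, kernel_size=(3, 3))

# The suggested fix: give each conv layer a non-linear activation
conv_relu = Conv2D(32, kernel_size=(3, 3), activation='relu')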

