首页 > 解决方案 > 当我使用 model.predict() 解决多类问题时,内核死了

问题描述

我使用 MacPro M1 芯片,python:3.8.10,TensorFlow:2.4.0-rc0。

在标签为 y 的多类问题(softmax)中是一个稀疏数字,即使有一些警告也能很好地工作,但model.fit()会导致内核立即死亡。model.evaluate()model.predict()

但是,如果标签 y 是一个单热向量,则所有 、 和 都适用于多类(softmax model.fit()model.evaluate()model.predict()

顺便说一句,所有model.fit(), model.evaluate(), 和 model.predict()在二元类和回归问题中也能很好地工作。

这是我的代码(导致内核死机)及其结果:

## Define a MLP model
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape = [28,28]),
    keras.layers.Dense(50, activation = "relu"),
    keras.layers.Dense(20, activation = "relu"),
    keras.layers.Dense(10, activation = "softmax")
])

model.compile(loss = "sparse_categorical_crossentropy",
              optimizer = "sgd",
              metrics = ["accuracy"]
             )

history = model.fit(X_train, y_train, epochs = 30,
                   validation_data = (X_valid, y_valid))

print(X_test.shape)
model.evaluate(X_test[:1], y_test[:1])

输出:

(10000, 28, 28)
1/1 [==============================] - 0s 8ms/step - loss: 0.0219 - accuracy: 1.0000

Out[21]:
[0.02185755781829357, 1.0]
X_new = X_test[:1]
print(X_new.shape)
print(X_new.shape[0])
y_prob = model.predict(X_new)

输出:

(1, 28, 28)
1
WARNING:tensorflow:AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x165553dc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: unsupported operand type(s) for -: 'NoneType' and 'int'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x165553dc0> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: unsupported operand type(s) for -: 'NoneType' and 'int'\
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

标签: pythontensorflowmachine-learning

解决方案


推荐阅读